基于CNTK/C#实现MNIST【附源码】
文章目录
- 前言
- 一、环境搭建
- 二、MNIST代码解析
- 1.GPU/CPU的设置
- 2.参数变量的设置
- 3.关联模型构造
- 4.模型构建
- 5.评价指标
- 6.数据集加载
- 7.学习率设置
- 8.获取模型训练器
- 9.模型训练
- 10.模型保存
- 11.模型验证
- 三、效果展示
- 四、源码下载
前言
本文实现基于CNTK实现MNIST,并对其实现CNN以及MLP方法进行测试
一、环境搭建
因为本次使用C#,我们只需要对C#的一些依赖库进行整理放入到项目中即可,这里我把所需要的所有的dll都整理到win64文件夹,在C# 的项目工程上只需要引入Cntk.Core.Managed-2.7.dll、netstandard.dll即可。
如图:

二、MNIST代码解析
这里我们直接上代码并解释每行代码的意思和构建的思路和步骤
1.GPU/CPU的设置
//GPU设置
var device = DeviceDescriptor.GPUDevice(0);
2.参数变量的设置
//使用cnn还是使用mlp
bool useConvolution = true;
//模型内的名称
var featureStreamName = "features";
var labelsStreamName = "labels";
var classifierName = "classifierOutput";
Function classifierOutput;//输入和输出定义
int[] imageDim = useConvolution ? new int[] { 28, 28, 1 } : new int[] { 784 };
int imageSize = 28 * 28;
int numClasses = 10;//设置模型的保存名称
string modelFile = useConvolution ? "./MNIST_CNN.model" : "./MNIST_MLP.model";
3.关联模型构造
//配置网络层的对应关系:features -> 28 * 28 labelsStreamName -> numClasses
IList<StreamConfiguration> streamConfigurations = new StreamConfiguration[]
{ new StreamConfiguration(featureStreamName, imageSize), new StreamConfiguration(labelsStreamName, numClasses) };
//定义输入变量,输出变量
var input = CNTKLib.InputVariable(imageDim, DataType.Float, featureStreamName);
var labels = CNTKLib.InputVariable(new int[] { numClasses }, DataType.Float, labelsStreamName);
4.模型构建
if (useConvolution) //cnn
{var scaledInput = CNTKLib.ElementTimes(Constant.Scalar<float>(0.00390625f, device), input);//构建CNN网络结构classifierOutput = CreateConvolutionalNeuralNetwork(scaledInput, numClasses, device, classifierName);
}
else //mlp
{int hiddenLayerDim = 200; //mlp的隐藏节点//构建MLP结构var scaledInput = CNTKLib.ElementTimes(Constant.Scalar<float>(0.00390625f, device), input);classifierOutput = CreateMLPClassifier(device, numClasses, hiddenLayerDim, scaledInput, classifierName);
}
static Function CreateConvolutionalNeuralNetwork(Variable features, int outDims, DeviceDescriptor device, string classifierName)
{//CNN网络结构的构建//初始化卷积层的参数int kernelWidth1 = 3, kernelHeight1 = 3, numInputChannels1 = 1, outFeatureMapCount1 = 4;int hStride1 = 2, vStride1 = 2;int poolingWindowWidth1 = 3, poolingWindowHeight1 = 3;// 28x28x1 -> 14x14x4 卷积+激活函数Function pooling1 = ConvolutionWithMaxPooling(features, device, kernelWidth1, kernelHeight1,numInputChannels1, outFeatureMapCount1, hStride1, vStride1, poolingWindowWidth1, poolingWindowHeight1);//初始化卷积层的参数int kernelWidth2 = 3, kernelHeight2 = 3, numInputChannels2 = outFeatureMapCount1, outFeatureMapCount2 = 8;int hStride2 = 2, vStride2 = 2;int poolingWindowWidth2 = 3, poolingWindowHeight2 = 3;// 14x14x4 -> 7x7x8 卷积+激活函数Function pooling2 = ConvolutionWithMaxPooling(pooling1, device, kernelWidth2, kernelHeight2,numInputChannels2, outFeatureMapCount2, hStride2, vStride2, poolingWindowWidth2, poolingWindowHeight2);//Dense层设计Function denseLayer = Dense(pooling2, outDims, device, Activation.None, classifierName);return denseLayer;
}
5.评价指标
//出模型后求损失
var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(new Variable(classifierOutput), labels, "lossFunction");
//出模型后求准确率
var prediction = CNTKLib.ClassificationError(new Variable(classifierOutput), labels, "classificationError");
6.数据集加载
//读取文件,数据集的读取,并确定好输入输出
var minibatchSource = MinibatchSource.TextFormatMinibatchSource("./mnist_data/MNIST_Train_cntk_text.txt", streamConfigurations, MinibatchSource.InfinitelyRepeat);
var featureStreamInfo = minibatchSource.StreamInfo(featureStreamName);
var labelStreamInfo = minibatchSource.StreamInfo(labelsStreamName);
7.学习率设置
//学习率的设置:SGD
TrainingParameterScheduleDouble learningRatePerSample = new TrainingParameterScheduleDouble(0.003125, 1);
IList<Learner> parameterLearners = new List<Learner>() { Learner.SGDLearner(classifierOutput.Parameters(), learningRatePerSample) };
8.获取模型训练器
var trainer = Trainer.CreateTrainer(classifierOutput, trainingLoss, prediction, parameterLearners);
9.模型训练
//开始训练
const uint minibatchSize = 64;
int outputFrequencyInMinibatches = 20, i = 0;
int epochs = 10000;
while (epochs > 0)
{var minibatchData = minibatchSource.GetNextMinibatch(minibatchSize, device);//定义好输入数据var arguments = new Dictionary<Variable, MinibatchData>{{ input, minibatchData[featureStreamInfo] },{ labels, minibatchData[labelStreamInfo] }};//训练入口trainer.TrainMinibatch(arguments, device);//输出训练结果PrintTrainingProgress(trainer, i++, outputFrequencyInMinibatches);if (MiniBatchDataIsSweepEnd(minibatchData.Values)){epochs--;}
}
10.模型保存
classifierOutput.Save(modelFile);
11.模型验证
// 验证模型
var minibatchSourceNewModel = MinibatchSource.TextFormatMinibatchSource("./mnist_data/MNIST_Test_cntk_text.txt", streamConfigurations, MinibatchSource.FullDataSweep);
ValidateModelWithMinibatchSource(modelFile, minibatchSourceNewModel, imageDim, numClasses, featureStreamName, labelsStreamName, classifierName, device);
三、效果展示

四、源码下载
源码下载
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
