作者:来瓶霸王防脱发
项目地址:
https://github.com/IntptrMax/ResnetTrain原文地址:
https://blog.csdn.net/qq_30270773/article/details/143060313项目背景ResNet算法是一种十分经典的多分类模型。ResNet是一种残差神经网络, 在2015 年的ILSVRC(ImageNet Large Scale Visual Recognition Challenge)中取得了冠军。本文主要写的是其中的一种模型ResNet18,属于ResNet模型家族中最为基础的一种。这个模型目前算是深度学习中十分经典的模型,在很多场景中都有使用。不过这种模型在C#环境下训练的例子还是很少的,在常见的科技类网站上搜索大多为C#加载onnx模型进行推理,但训练的部分几乎没有任何资料。本文就来介绍如何在纯C#的环境下进行ResNet18网络的训练和推理。
算法实现本文主要介绍如何实现ResNet网络的训练和推理
模型介绍ResNet模型主要由Bottleneck,和4个Layers组成,其中每个Layer由是由两个BasicBlock组成,每个BasicBlock由Conv1,bn1,relu,conv2,bn2,downsample组成,所以整个网络的结构并不算复杂。结构如下图所示,更为具体的资料请参考
https://arxiv.org/abs/1512.03385
ResNet18网络训练的图像预处理技术十分常规,使用了Resize和Normalize的方法,并没有使用LetterBox。有兴趣的话也可以自行替换成LetterBox方法。核心代码如下:
torchvision.io.DefaultImager = new torchvision.io.SkiaImager();var transformers = torchvision.transforms.Compose([ torchvision.transforms.Resize(resizeHeight,resizeWidth), torchvision.transforms.CenterCrop(cropHeight,cropWidth), torchvision.transforms.Normalize(means, stdevs)]);Tensor imgTensor = torchvision.io.read_image(file) / 255.0f;imgTensor = transformers.call(imgTensor.unsqueeze(0));var labelTensor = torch.tensor(tagIndex, dtype: torch.int64);var tensorDataDic = new Dictionary<string, Tensor>();tensorDataDic.Add("tag", labelTensor);tensorDataDic.Add("img", imgTensor.squeeze(0));return tensorDataDic;训练过程训练过程是十分常规的方式。直接上代码。因为使用的数据只有5个分类,所以这里参数值使用了5,默认是1000,如果保持默认值,推理的结果会有1000项,取前5项做预测即可。本文中使用了Adam优化器,也可以尝试自行替换。
static void Train(){ var resnet = torchvision.models.resnet18(5).to(device, scalarType); var optimizer = torch.optim.Adam(resnet.parameters()); var cross_loss = new CrossEntropyLoss(); DataClass dataClass = new DataClass(path, ResizeWidth, ResizeHeight, CropWidth, CropHeight); var trainLoader = new DataLoader(dataClass, batchSize: BatchSize, shuffle: true, device: device, num_worker: Workers);for (int loop = 0; loop < Loop; loop++) { int step = 0; foreach (var trainBatch in trainLoader) { step++; var img = trainBatch["img"]; var tag = trainBatch["tag"]; optimizer.zero_grad(); Tensor re = resnet.forward(img); var loss = cross_loss.forward(re, tag); loss.backward(); optimizer.step(); var eval = loss.ToSingle(); GC.Collect(); Console.WriteLine("Loop:{0} Setp:{1} / {2} ; Loss:{3}", loop, step, trainLoader.Count, eval); } } resnet.save(modelName); Console.WriteLine("Train Done.");}推理过程推理过程更为简单,加载模型后即可进行推理,使用SoftMax方法可以求解到最终结果。
static void Test(){ string testPath = path; DataClass dataClass = new DataClass(testPath); var data = dataClass.GetTensor(0); var img = data["img"].unsqueeze(0).to(scalarType, device); var resnet = torchvision.models.resnet18(5,device: device).to(scalarType); resnet.load(modelName); resnet.eval(); Tensor re = (Tensor)resnet.forward(img); re = torch.softmax(re, 1); var (max, index) = re.max(1); Console.WriteLine("Predction:{0}\r\nScore:{1}\r\nTag:{2}", dataClass.GetTagNameByTag((int)index.ToInt64()), max.ToSingle(), dataClass.GetTagName(0));}实现效果通过训练+推理,可以得到如下效果,其中训练40轮,最终推理的准确度已经非常高了。

使用C#深度学习项目是很多人所希望的。不过在该方向上资料很少,开发难度大。常规使用C#进行深度学习项目的方法为使用Python训练,转为Onnx模型再用C#调用。目前我希望能够改变这一现象,希望能用纯C#平台进行训练和推理。这条路还很长,也很困难,希望有兴趣的读者能跟我一起让让C#的深度学习开发环境更为完善,以此能帮助到更多的人。
我在Github上已经将完整的代码发布了,项目地址为:
https://github.com/IntptrMax/ResnetTrain,期待你能在Github上送我一颗小星星。在我的Github里还GGMLSharp这个项目,这个项目也是C#平台下深度学习的开发包,希望能得到你的支持。
项目下载链接 https://download.csdn.net/download/qq_30270773/89897710