当前位置: 首页 > news >正文

ResNet18在MNIST手写数字数据库上的深度学习网络识别及Matlab仿真实验研究

ResNet18深度学习网络的mnist手写数字数据库识别matlab仿真

MNIST手写数字识别算是深度学习界的"Hello World"了,不过这次咱们用ResNet18来整点不一样的。别看ResNet本来是给ImageNet设计的,拿来折腾下28x28的小图片还挺有意思。先说说数据准备这块,Matlab处理起来比Python其实更省心:

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath,... 'IncludeSubfolders',true,'LabelSource','foldernames'); [imdsTrain,imdsTest] = splitEachLabel(imds,0.8,'randomized');

这里要注意个坑,原始ResNet输入是224x224的RGB图。咱们得给灰度图加个戏——用augmentedImageDatastore强行拉伸尺寸,虽然有点暴力但效果还行:

inputSize = [224 224 3]; augImdsTrain = augmentedImageDatastore(inputSize,imdsTrain,'ColorPreprocessing','rgb'); augImdsTest = augmentedImageDatastore(inputSize,imdsTest,'ColorPreprocessing','rgb');

接下来构建网络骨架。Matlab自带的resnet18其实可以直接魔改,但为了展示原理,咱们手搓一个残差块:

function lgraph = addBasicBlock(lgraph, blockName, numFilters, stride, inputLayerName) conv1_name = [blockName '_conv1']; bn1_name = [blockName '_bn1']; conv2_name = [blockName '_conv2']; bn2_name = [blockName '_bn2']; add_name = [blockName '_add']; % 残差路径 lgraph = addLayers(lgraph, [ convolution2dLayer(3,numFilters,'Stride',stride,'Padding','same','Name',conv1_name) batchNormalizationLayer('Name',bn1_name) reluLayer('Name',[blockName '_relu1']) convolution2dLayer(3,numFilters,'Padding','same','Name',conv2_name) batchNormalizationLayer('Name',bn2_name) ]); % shortcut连接 if stride ~= 1 shortcut = [ convolution2dLayer(1,numFilters,'Stride',stride,'Name',[blockName '_shortcut_conv']) batchNormalizationLayer('Name',[blockName '_shortcut_bn']) ]; lgraph = addLayers(lgraph, shortcut); lgraph = connectLayers(lgraph, inputLayerName, [blockName '_shortcut_conv']); else lgraph = connectLayers(lgraph, inputLayerName, add_name+'/in2'); end % 合并残差 lgraph = addLayers(lgraph, additionLayer(2,'Name',add_name)); lgraph = connectLayers(lgraph, bn2_name, [add_name '/in1']); end

这个残差块实现有几个精妙之处:当stride不为1时需要1x1卷积调整维度,否则直接相加。注意Matlab的加法层要处理两个输入源的连接,这里用connectLayers手动指定连接关系比自动构建更靠谱。

ResNet18深度学习网络的mnist手写数字数据库识别matlab仿真

训练配置这块别照搬ImageNet那套,学习率得调小点:

options = trainingOptions('sgdm',... 'InitialLearnRate',0.1,... 'LearnRateSchedule','piecewise',... 'LearnRateDropPeriod',5,... 'MaxEpochs',15,... 'Shuffle','every-epoch',... 'Plots','training-progress',... 'ValidationData',augImdsTest);

跑完15个epoch基本能到99.2%左右的准确率。测试时有个小技巧,用classify函数直接输出预测结果:

[YPred,probs] = classify(net,augImdsTest); YTest = imdsTest.Labels; accuracy = sum(YPred == YTest)/numel(YTest)

最后画混淆矩阵的时候,建议用自定义颜色更直观:

cm = confusionchart(YTest, YPred); cm.Title = 'ResNet18在MNIST上的混淆矩阵'; cm.ColumnSummary = 'column-normalized'; cm.RowSummary = 'row-normalized'; cm.FontSize = 12;

整个过程跑下来发现,虽然用ResNet18处理MNIST有点杀鸡用牛刀,但残差连接确实能加速训练收敛。有意思的是把图片强行拉伸到224x224后,网络前几层的特征图会保留更多细节,这对识别边缘尖锐的手写数字反而有帮助。不过要注意全连接层最后别用默认的1000输出,记得改成10分类哦!

http://www.jsqmd.com/news/487665/

相关文章:

  • PyCharm界面介绍
  • 基于zxing生成二维码
  • 时序数据库选型指南:从架构演进看Apache IoTDB的工业级优势
  • map映射和哈希映射
  • 未来 5 年,对于程序员群体而言非AI 大模型莫属!
  • 鸿蒙中 卡片交互:message事件(三)
  • 工作总结-接口设计
  • 西门子smart 200 rtu方式通讯四台三菱E700变频器资料 硬件:smart plc...
  • ChatGPT 引言写作指南:从新手到高手的结构化方法
  • YOLO系列算法改进 | 主干改进篇 | 替换ParameterNet参数优先网络 | 利用动态卷积自适应调整卷积核,助力模型低光照下增强边缘响应 | CVPR 2024
  • 永磁同步电机矢量控制FOC仿真:id=0与MTPA两种控制策略的对比分析与参考文献
  • P2679 [NOIP 2015 提高组] 子串
  • 3-16午夜盘思
  • 深入探究:直流电机单双闭环调速系统仿真模型与参数优化设计报告
  • XSLT快速入门:XML转换全攻略
  • 【论文精读】CodeWMBench 揭示 AI 生成代码水印的残酷真相
  • AudioSeal Pixel Studio从零开始:Windows平台Anaconda环境完整配置流程
  • TB6612FNG直流电机驱动板原理图设计,已量产
  • 工业级隔离型RS485接口电路原理图设计,已量产
  • 孙珍妮AI形象生成镜像指南:Z-Image-Turbo LoRA模型安全加载与沙箱隔离配置
  • Cosmos-Reason1-7B企业应用:化工厂监控视频中识别泄漏源与扩散模拟建议
  • 探索COMSOL中的Merging off-gamma BIC计算
  • std::process::Command
  • 用M文件在Matlab 2019a中实现两电平三相SVPWM
  • 乐高兼容ESP32对讲机:模块化嵌入式音频通信设计
  • 旋转卡壳
  • 基于Simulink的固定频率滞环电流控制Boost变换器
  • 南北阁Nanbeige 4.1-3B行业方案:数据库课程设计智能辅导系统
  • HCIP第二次作业
  • YOLOv8训练Visidron小目标检测数据集及精度提升实践