当前位置: 首页 > news >正文

从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率刷到99.5%以上的

从简单CNN到ResNet18:我是如何一步步把MNIST手写数字识别准确率刷到99.5%以上的

第一次用CNN跑MNIST时,看着测试集上98%的准确率还挺满意。直到在Kaggle上看到有人用相同数据集跑出99.5%+的成绩,才发现自己连入门级数据集的潜力都没榨干。这就像以为掌握了加减乘除就能解微积分——深度学习的水远比想象中深。经过两个月的反复实验,终于让模型突破了99.5%大关,整个过程堪称一部"调参侠"的进化史。

1. 基础CNN的瓶颈与突破

初始的CNN架构简单得可怜:两个卷积层夹着ReLU和最大池化,最后接全连接层。这个模型在10个epoch后就稳定在98.1%左右,典型的"早熟"表现。通过TensorBoard可视化发现,验证集准确率在第5轮后就几乎走平,说明模型容量根本不够。

第一批改进方案:

# 关键改进点代码示例 class EnhancedCNN(nn.Module): def __init__(self): super().__init__() self.block1 = nn.Sequential( nn.Conv2d(1, 32, 5, padding=2), # 保持特征图尺寸 nn.BatchNorm2d(32), # 新增批归一化 nn.ReLU(inplace=True), nn.MaxPool2d(2)) self.block2 = nn.Sequential( nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), # 新增批归一化 nn.ReLU(inplace=True), nn.MaxPool2d(2)) self.classifier = nn.Sequential( nn.Flatten(), nn.Dropout(0.5), # 新增Dropout nn.Linear(64*7*7, 10))

调整后的模型出现了几个明显变化:

  • 通道数从[10,20]扩展到[32,64],增强特征提取能力
  • 添加BatchNorm层后,学习率可以提升3倍而不发散
  • 引入Dropout后训练集准确率下降,但验证集提升0.6%

注意:BatchNorm一定要放在卷积层和激活函数之间,这个顺序错误会导致效果大打折扣

验证集准确率变化:

改进措施准确率提升训练时间增幅
基础CNN98.1%-
+BatchNorm+0.9%+15%
+通道扩展+0.7%+25%
+Dropout+0.6%可忽略

2. 数据增强的艺术

当模型在原始数据上达到98.7%后,我开始在数据层面寻找突破点。MNIST的简单特性决定了不能使用太激进的数据增强,经过反复测试,最终确定了最佳组合:

transform = transforms.Compose([ transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), # 微小平移 transforms.RandomRotation((-5, 5)), # 小角度旋转 transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])

数据增强效果对比:

  • 纯平移增强:+0.25%
  • 纯旋转增强:+0.18%
  • 组合增强:+0.42%
  • 添加弹性变形:-0.3%(过犹不及)

有趣的是,当增强幅度过大时(如旋转±15度),模型准确率反而下降。这是因为MNIST数字的形态特征比自然图像更敏感,过度变形会让"9"变得像"4"、"7"像"1"。

3. 学习率动态调整策略

固定学习率就像用固定速度爬山——平缓处太慢,陡峭处又容易翻车。尝试了三种动态调整方案:

  1. StepLR:每30个epoch乘以0.1
    scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
  2. Cosine退火
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
  3. ReduceLROnPlateau
    scheduler = lr_scheduler.ReduceLROnPlateau( optimizer, mode='max', factor=0.5, patience=3)

实验结果:

  • StepLR在中期表现最好,但后期下降过快
  • Cosine退火整体平稳,但最高点不如ReduceLROnPlateau
  • ReduceLROnPlateau最终达到99.23%,是最佳选择

提示:监控验证集准确率而非训练损失作为调整依据,这样更可靠

4. 残差连接的降维打击

当传统CNN改进陷入瓶颈时,ResNet18带来了质的飞跃。但直接将ImageNet的架构用于MNIST会过犹不及,需要做针对性调整:

class MNISTResNet(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 16, 3, padding=1) # 输入通道改为1 self.bn1 = nn.BatchNorm2d(16) self.relu = nn.ReLU(inplace=True) # 简化版的残差块 self.layer1 = self._make_layer(16, 16, 2) self.layer2 = self._make_layer(16, 32, 2, stride=2) self.layer3 = self._make_layer(32, 64, 2, stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1,1)) self.fc = nn.Linear(64, 10) def _make_layer(self, in_channels, out_channels, blocks, stride=1): downsample = None if stride != 1 or in_channels != out_channels: downsample = nn.Sequential( nn.Conv2d(in_channels, out_channels, 1, stride), nn.BatchNorm2d(out_channels)) layers = [] layers.append(ResidualBlock(in_channels, out_channels, stride, downsample)) for _ in range(1, blocks): layers.append(ResidualBlock(out_channels, out_channels)) return nn.Sequential(*layers)

关键改进点:

  • 将原始ResNet18的4层残差块减为3层
  • 初始卷积核从7x7改为3x3
  • 最终平均池化层输出尺寸设为1x1
  • 通道数缩减为[16,32,64]以适应小图像

性能对比:

模型类型参数量测试准确率训练时间(epoch)
增强版CNN1.2M99.23%45min
简化ResNet180.8M99.47%68min
标准ResNet1811.2M99.31%2.5h

5. 突破99.5%的终极组合

最终的突破来自多个微创新的叠加效应:

  1. 权重初始化:改用Kaiming初始化

    def init_weights(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') model.apply(init_weights)
  2. 优化器切换:从SGD改为RMSprop

    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)
  3. 标签平滑:缓解过拟合

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
  4. 测试时增强:对测试图像做5次随机变换取平均

    def predict(image): model.eval() outputs = [] for _ in range(5): aug_img = test_transform(image) # 包含随机变换 outputs.append(model(aug_img.unsqueeze(0))) return torch.mean(torch.stack(outputs), dim=0)

最终在测试集上的准确率曲线呈现出有趣的规律:每当引入一个新技巧,准确率就会上一个台阶,但提升幅度越来越小。从98%到99%相对容易,但从99%到99.5%需要付出十倍努力——这大概就是深度学习的边际效应吧。

http://www.jsqmd.com/news/827407/

相关文章:

  • .NET逆向工程新选择:dnSpyEx调试器与程序集编辑全解析
  • 别再乱写了!用Arduino玩转AT24C16 EEPROM,详解页写覆盖与跨页读写避坑
  • [017][web模块]基于计数器的接口幂等性与访问限流设计实战
  • 量子计算突破:超精细耦合常数计算新方法
  • 记录下我知道的去中心化网络协议
  • 5分钟快速上手:浏览器串口助手终极指南
  • 手把手教你用Proteus 8.15仿真STM32F103流水灯(STM32CubeMX + Keil MDK-ARM保姆级教程)
  • 2026年灵动女王脸多变风格排名 - myqiye
  • Linux I2C驱动调试踩坑记:MPU6050数据读取为何总报EIO错误?
  • 从入门到精通:trtexec命令行工具在TensorRT模型部署中的实战指南
  • ARM Cortex-A9 MPCore多核处理器架构与优化实践
  • 手把手教你用CMake和Ninja在Windows上编译免费Aseprite(附Skia配置避坑指南)
  • discli:命令行界面聚合框架,提升DevOps与云原生开发效率
  • 2分钟看完一周AI大事
  • 构建可信AI代理:从可观测性到安全沙箱的工程实践
  • ARM GIC中断控制器架构与寄存器编程详解
  • 2026年合同纠纷处理靠谱律所推荐,福峰所专业 - myqiye
  • 智能体“出逃”与管控:防止 AI Agent Harness Engineering 行为失范的技术
  • 量子计算性能评估:从基础指标到应用实践
  • Git分支管理工具branchlet:提升开发效率的轻量级命令行利器
  • 2026年物流公司口碑排名,哪个值得信赖? - 工业品牌热点
  • 构建个人智能数据仓:从信息孤岛到知识网络的实践指南
  • 【SCL实战】从冒泡排序到电梯调度:揭秘for循环在工业控制中的核心应用
  • Free NTFS for Mac终极指南:打破macOS读写限制的完整解决方案
  • 3个技巧让LaTeX参考文献自动符合GB/T 7714国标:告别手动排版烦恼
  • 从零搭建家庭实验室:开源项目ansh-info/homelab实践指南
  • 开源身份认证中心Casdoor:统一用户管理与单点登录实践指南
  • 2026年论文降AI攻略:亲测几款免费降AI工具,降低ai率,告别知乎维普AIGC率飘红 - 降AI实验室
  • 物流加工厂选购指南,上海楚基告诉你 - 工业品牌热点
  • [RKNN] 模型转换与推理实战:从YOLOX部署看API核心用法与性能调优