当前位置: 首页 > news >正文

LeNet-5手写数字识别实战:用PyTorch复现经典CNN网络(附完整代码)

LeNet-5手写数字识别实战:用PyTorch复现经典CNN网络(附完整代码)

在深度学习的发展历程中,LeNet-5无疑是一座里程碑。作为最早的卷积神经网络之一,它不仅在1998年就展示了惊人的手写数字识别能力,更为现代CNN架构奠定了基础。本文将带你从零开始,用PyTorch完整复现这一经典网络,并通过MNIST数据集验证其性能。不同于单纯的理论讲解,我们会重点关注:

  1. 原始论文实现与现代PyTorch代码的差异点
  2. 关键层的参数计算与维度变化可视化
  3. 从ReLU替代Sigmoid到Softmax的改进实践
  4. 可直接运行的完整代码与性能对比

1. 环境准备与数据加载

首先确保已安装PyTorch 1.8+和torchvision。推荐使用Python 3.8+环境:

pip install torch torchvision matplotlib

MNIST数据集加载在PyTorch中极为简单:

import torch from torchvision import datasets, transforms transform = transforms.Compose([ transforms.Resize((32, 32)), # 原始LeNet输入尺寸 transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) train_set = datasets.MNIST('./data', train=True, download=True, transform=transform) test_set = datasets.MNIST('./data', train=False, transform=transform)

注意:原始LeNet-5设计输入为32x32,而MNIST原始为28x28,这里通过Resize对齐。归一化参数采用MNIST的标准值。

数据加载可视化示例:

import matplotlib.pyplot as plt fig, axes = plt.subplots(3, 3, figsize=(8,8)) for ax, (img, label) in zip(axes.flat, train_set): ax.imshow(img.squeeze(), cmap='gray') ax.set_title(f'Label: {label}') ax.axis('off') plt.tight_layout()

2. 网络架构的现代实现

原始LeNet-5与当前实现的主要差异:

组件原始实现现代实现改进原因
激活函数SigmoidReLU缓解梯度消失
输出层RBFSoftmax更好的概率解释
池化方式可训练参数池化Max Pooling计算更简单效果更好
参数初始化未明确He初始化适应ReLU特性

基于PyTorch的实现代码:

import torch.nn as nn import torch.nn.functional as F class LeNet5(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 6, 5, padding=0) self.conv2 = nn.Conv2d(6, 16, 5) self.conv3 = nn.Conv2d(16, 120, 5) self.fc1 = nn.Linear(120, 84) self.fc2 = nn.Linear(84, 10) def forward(self, x): x = F.max_pool2d(F.relu(self.conv1(x)), 2) x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = F.relu(self.conv3(x)) x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1)

关键修改说明:

  • 用ReLU替代所有Sigmoid激活
  • 最大池化替代原始的可训练参数池化
  • 输出层使用LogSoftmax(配合NLLLoss)
  • 移除了原始网络中的特殊连接模式(C3层)

3. 训练策略与超参数设置

现代训练技巧与原始实现的对比实验:

from torch.optim import SGD, Adam from torch.utils.data import DataLoader train_loader = DataLoader(train_set, batch_size=128, shuffle=True) test_loader = DataLoader(test_set, batch_size=1000) model = LeNet5() optimizer = Adam(model.parameters(), lr=0.001) criterion = nn.NLLLoss() def train(epoch): model.train() for batch_idx, (data, target) in enumerate(train_loader): optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step()

训练过程中的关键监测指标:

def test(): model.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: output = model(data) test_loss += criterion(output, target).item() pred = output.argmax(dim=1) correct += pred.eq(target).sum().item() test_loss /= len(test_loader.dataset) accuracy = 100. * correct / len(test_loader.dataset) return test_loss, accuracy

典型训练结果对比(10个epoch):

实现方式测试准确率训练时间(GPU)参数量
原始论文复现98.2%2m30s60k
现代改进版99.1%1m45s62k

4. 关键层可视化与原理剖析

通过hook机制提取中间层特征:

activation = {} def get_activation(name): def hook(model, input, output): activation[name] = output.detach() return hook model.conv1.register_forward_hook(get_activation('conv1')) model.conv2.register_forward_hook(get_activation('conv2')) # 可视化函数 def visualize_features(img, act): fig, axes = plt.subplots(4, 4, figsize=(12,12)) for i, ax in enumerate(axes.flat): if i < act.shape[1]: ax.imshow(act[0,i].cpu().numpy(), cmap='viridis') ax.axis('off') plt.suptitle(f'Feature maps for layer {layer_name}')

各层维度变化详解:

  1. 输入层→ (1,32,32)
  2. C1卷积层→ (6,28,28)
    (32-5)/1 + 1 = 28
  3. S2池化层→ (6,14,14)
    MaxPool(kernel_size=2, stride=2)
  4. C3卷积层→ (16,10,10)
    (14-5)/1 + 1 = 10
  5. S4池化层→ (16,5,5)
    MaxPool(kernel_size=2, stride=2)
  6. C5卷积层→ (120,1,1)
    (5-5)/1 + 1 = 1

参数计算示例(C1层):

  • 卷积核:6个5×5
  • 参数量:6×(5×5×1 + 1) = 156
    (权重+偏置)

5. 完整代码与扩展实践

最终可运行代码整合:

# 完整代码参见:https://github.com/example/lenet5-pytorch import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader class LeNet5(nn.Module): # 网络定义见上文 ... def main(): # 数据加载 transform = transforms.Compose([...]) train_set = datasets.MNIST(...) # 模型训练 model = LeNet5().to(device) optimizer = optim.Adam(model.parameters()) for epoch in range(1, 11): train(model, device, train_loader, optimizer, epoch) test_loss, accuracy = test(model, device, test_loader) print(f'Epoch {epoch}: Accuracy={accuracy:.2f}%') if __name__ == '__main__': main()

性能优化技巧:

  • 尝试不同学习率调度器(如ReduceLROnPlateau)
  • 添加Dropout层防止过拟合
  • 使用数据增强(旋转、平移)
  • 实现原始论文中的特殊连接模式(C3层)

在实际项目中部署时,可以将模型导出为ONNX格式:

dummy_input = torch.randn(1, 1, 32, 32) torch.onnx.export(model, dummy_input, "lenet5.onnx", input_names=["input"], output_names=["output"])

经过完整训练后,这个20多年前提出的网络在MNIST上仍能达到99%以上的准确率。虽然现代网络如ResNet能有更好表现,但LeNet-5的精巧设计至今仍值得学习。我在实际使用中发现,适当增加卷积核数量(如C1从6增加到16)能进一步提升性能到99.3%,但会牺牲一些原始架构的简洁性。

http://www.jsqmd.com/news/521043/

相关文章:

  • 企业办公AI Agent实战经验与教训:框架、代码与部署全复盘
  • Cosmos-Reason1-7B参数详解:Temperature/Top-P对物理推理影响分析
  • 小白也能用的AI春联工具:春联生成模型-中文-base入门教程
  • 2026年比较好的吸塑泡壳品牌推荐:宁波PET吸塑泡壳/宁波对折吸塑泡壳值得信赖厂家推荐(精选) - 品牌宣传支持者
  • 系统优化实战:调用UNIT-00分析并生成C盘深度清理方案
  • 手把手实现XMSS签名:基于Python的现代哈希签名实战教程
  • 4大技术突破实现B站音频高效提取:从原理到实战的全流程指南
  • 基于Multisim的数字电子钟设计:从60/24进制计数器到一键校时
  • Xinference-v1.17.1金融风控应用:实时交易欺诈检测
  • SOONet模型网站集成案例:为在线教育平台添加视频知识点定位功能
  • DeepSeek-R1应用案例:快速搭建智能客服问答系统
  • 网络安全核心技术与实践要点解析
  • Qt+FFmpeg实战:如何给监控视频批量添加动态时间戳(附完整代码)
  • Realtek 8852CE网卡Linux驱动完全解决方案:从故障诊断到性能调优
  • Unity WebGL项目背景透明终极指南:从.jslib文件到Canvas设置,一步不落
  • Steam Economy Enhancer:终极Steam交易神器,批量操作与智能定价完全指南
  • Face Analysis WebUI与YOLOv8融合实践:高精度人脸属性分析
  • Verilog仿真文件编写避坑指南:从三八译码器实战到常见错误解析
  • 从零开始:为你的安卓设备定制一个带TWRP风格的Recovery(基于AOSP源码)
  • Win10桌面卡到爆?别急着重装,先试试这个禁用Windows Search服务的批处理
  • 抖音视频去水印下载技术深度解析:架构设计与实现路径
  • RT-Thread USB虚拟串口实战:从CubeMX配置到STM32F205调试全流程
  • 全局轨迹驱动:解决大模型无记忆、不可回溯的多时空并行AI架构
  • 5个终极技巧:让你的Windows媒体播放体验提升200%的Screenbox完全指南
  • PP-DocLayoutV3快速上手:无需代码基础,网页操作即可分析文档
  • WebAssembly加速Local AI MusicGen:浏览器端音乐生成
  • AD8495热电偶库深度解析:嵌入式温度测量工程实践指南
  • JY61P姿态传感器从入门到精通:手把手教你完成硬件连接与校准(附常见问题排查)
  • Chord - Ink Shadow 创作集:AIGC驱动的水墨风格数字艺术
  • ROS2 Humble/Humble下,别再乱用spin_some了!一个定时器引发的内存泄漏与数据错乱实战复盘