当前位置: 首页 > news >正文

技术博文:基于 PyTorch 实现经典 LeNet-5 手写数字识别

技术博文:基于 PyTorch 实现经典 LeNet-5 手写数字识别

一、前言

LeNet-5 是深度学习入门最经典的卷积神经网络,由 Yann LeCun 提出,专门用于手写数字识别(MNIST 数据集)。本文将完整讲解模型结构、核心层参数、代码实现、训练与评估,新手也能从零跑通。


二、LeNet-5 模型详解

1. 模型整体结构

LeNet-5 一共5 层核心结构

  • 3 个卷积层(Conv2d):提取图像特征
  • 2 个全连接层(Linear):分类决策
  • 配合ReLU 激活+MaxPool 池化

适配 MNIST 数据:输入为1×28×28(单通道灰度图),输出为10 分类(0~9)

2. 核心层参数解释(必懂)

(1)卷积层nn.Conv2d

作用:对图像进行特征提取,捕捉边缘、纹理等信息。

python

运行

nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
  • in_channels:输入通道数(灰度图 = 1,彩色图 = 3)
  • out_channels:输出通道数(提取多少种特征)
  • kernel_size:卷积核大小(常用 3×3、5×5)
  • stride:步长(卷积核移动距离)
  • padding:填充(保持图像尺寸不缩小)

LeNet 中卷积层:

plaintext

conv1: 1→6,5×5,padding=2 (尺寸不变) conv2: 6→16,5×5 (尺寸缩小) conv3:16→120,5×5 (压缩为 1×1 特征)

(2)全连接层nn.Linear

作用:将高维特征映射为类别概率,完成最终分类。

python

运行

nn.Linear(in_features, out_features)
  • in_features:输入特征数量
  • out_features:输出特征数量(分类数)

LeNet 中全连接层:

plaintext

fc4: 120 → 84 fc5: 84 → 10 (输出 0~9 十个数字)

三、完整实现过程

3.1 模型构建(LeNet-5)

python

运行

import torch import torch.nn as nn class Lenet5(nn.Module): def __init__(self): super(Lenet5, self).__init__() # 第1层:卷积 + 激活 + 池化 self.conv1 = nn.Conv2d(1, 6, 5, 1, 2) self.relu1 = nn.ReLU() self.pool1 = nn.MaxPool2d(2) # 第2层:卷积 + 激活 + 池化 self.conv2 = nn.Conv2d(6, 16, 5) self.relu2 = nn.ReLU() self.pool2 = nn.MaxPool2d(2) # 第3层:卷积 + 激活 self.conv3 = nn.Conv2d(16, 120, 5) self.relu3 = nn.ReLU() # 第4层:全连接 self.fc4 = nn.Linear(120, 84) self.relu4 = nn.ReLU() # 第5层:输出层 self.fc5 = nn.Linear(84, 10) def forward(self, x): # 前向传播 x = self.pool1(self.relu1(self.conv1(x))) x = self.pool2(self.relu2(self.conv2(x))) x = self.relu3(self.conv3(x)) x = x.view(-1, 120) # 展平 x = self.relu4(self.fc4(x)) x = self.fc5(x) return x

3.2 数据加载(MNIST)

from torchvision.transforms import Compose, ToTensor from torch.utils.data import DataLoader from torchvision.datasets import MNIST def load_minist(root="./mnist", batch_size=64): # 数据转换:图片 → 张量 transform = Compose([ToTensor()]) # 训练集 + 测试集 ds_train = MNIST(root, train=True, download=True, transform=transform) ds_valid = MNIST(root, train=False, download=True, transform=transform) # 构建批次加载器 loader_train = DataLoader(ds_train, shuffle=True, batch_size=batch_size) loader_valid = DataLoader(ds_valid, shuffle=False, batch_size=1000) return loader_train, loader_valid

3.3 模型训练

def train_one(model, loader, loss_fun, optimizer, device="cpu"): model.train() model.to(device) for x, y in loader: optimizer.zero_grad() x, y = x.to(device), y.to(device) y_ = model(x) loss = loss_fun(y_, y) loss.backward() optimizer.step()

3.4 模型评估

@torch.no_grad() def evaluate(model, loader, loss_fun, device="cpu"): model.eval() model.to(device) total = 0 correct = 0 total_loss = 0 for x, y in loader: x, y = x.to(device), y.to(device) y_ = model(x) loss = loss_fun(y_, y) total_loss += loss.item() _, pred = torch.max(y_, dim=1) correct += (pred == y).sum().item() total += len(x) acc = correct / total * 100 print(f"测试损失:{total_loss:.6f},准确率:{acc:.2f}%")

3.5 主训练流程

def train(epochs=5, lr=0.001): model = Lenet5() loader_train, loader_valid = load_minist() loss_fun = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=lr) for epoch in range(epochs): print(f"======== 第 {epoch+1} 轮 ========") train_one(model, loader_train, loss_fun, optimizer) evaluate(model, loader_valid, loss_fun) torch.save(model.state_dict(), "lenet5.pth") if __name__ == "__main__": train(epochs=5)

四、运行结果

训练 5 轮左右,测试集准确率可达 98%~99%。输出示例:

plaintext

======== 第 5 轮 ======== 测试损失:8.2645,准确率:98.74%

五、总结

  1. LeNet-5:3 卷积 + 2 全连接,手写数字识别经典模型
  2. Conv2d:提取图像特征,控制通道与尺寸
  3. Linear:分类输出,映射为 10 个数字类别
  4. 训练流程:数据加载 → 前向传播 → 计算损失 → 反向更新 → 评估
  5. 效果:MNIST 上轻松达到 98%+ 准确率
http://www.jsqmd.com/news/778277/

相关文章:

  • 2026年郑州汽车贴膜行业横向测评:5家主流门店深度对比 - 贴膜攒钱买霍希
  • gh_mirrors/in/invoice错误排查手册:常见问题与解决方案大全
  • DeepWay冲刺港股:年营收近40亿亏6.5亿 刚融资超3亿美元 百度与中东资本加持
  • AI原生代码审查知识库BeforeMerge:结构化规则赋能高效开发
  • Unity中解决Windows构建可寻址捆包后,程序加载时提示‘build target is 13’(对应安卓)出错问题解决方案
  • Glowby OSS:本地化AI编码代理工作流,实现生产就绪代码精炼
  • 利用 Taotoken 多模型能力为智能体应用提供稳定后端
  • 调频连续波 (FMCW) 雷达(一)距离测量
  • 油猴简书净化 - 冷夜
  • 提示工程实战指南:从核心原则到高级应用场景解析
  • YOLO训练翻车实录:从‘dog’和‘man’数据集到工业缺陷检测的实战避坑指南
  • Armv9-A架构扩展与嵌入式追踪技术解析
  • AI 内容导出乱、格式崩、公式变?我开发了这只鸭子帮我全解决了(三)** AI导出鸭 专写学生篇:从课堂笔记到毕业论文,AI 导出的那些坑
  • 基于SwiftUI与Combine的AR眼镜AI语音助手开发实战
  • 企业边缘计算设备INA1607:硬件架构与应用解析
  • 2026 年郑州首选:百莱创汽车贴膜工厂店靠谱揭秘 - 贴膜攒钱买霍希
  • 机器人通信的通信渠道
  • AI 内容导出乱、格式崩、公式变?我开发了这只鸭子帮我全解决了(五)** AI导出鸭 专写开发者篇:技术文档、代码导出、API文档,那些细节决定成败
  • 2026宁波婚纱摄影口碑排名:从客户真实评价数据,看宁波婚纱照哪家好 - charlieruizvin
  • Z-Image开源工具用户反馈实录:AI工程师如何用Z-Image-LM提升调试效率3倍
  • 从OpenClaw到Bramble:构建可破解、安全可控的AI代理框架实践
  • 别再写流水账了!用这个在线电影管理系统用例规约模板,3分钟搞定核心业务逻辑
  • CTFshow文件上传刷题
  • TypeORM游标分页库实战:解决大数据量分页的性能与一致性难题
  • 国内CNAS检测机构排行:权威合规与服务能力对比 - 奔跑123
  • AI设计:零基础用稿定设计+AI提示词快速生成技术封面与海报
  • 基于MCP协议构建本地AI文档解析服务器:rendoc-mcp-server实战指南
  • Chaterm:AI原生终端如何重塑运维工作流与团队协作
  • Vue+React混合架构实战:构建AI地图搜索与地理CRM应用
  • 从混淆矩阵到AUC:5分钟搞懂P-R曲线和ROC曲线的区别与联系