当前位置: 首页 > news >正文

深度学习入门:用PyTorch实现MNIST手写数字识别

以下是一个对新手既易懂又有价值的深度学习入门案例,关键在于将核心概念(如数据、模型、损失、优化)与直观、可运行的具体代码相结合。一个经典的入门案例是使用全连接神经网络(Fully Connected Neural Network, FCN)在MNIST手写数字数据集上进行图像分类。这个案例价值在于:

  1. 问题直观:识别0-9的手写数字,结果易于理解。
  2. 数据标准:MNIST是深度学习领域的“Hello World”,数据已预处理,便于聚焦模型本身。
  3. 涵盖完整流程:从数据加载、模型定义、训练到评估,覆盖了深度学习项目的基本闭环。
  4. 快速见效:模型相对简单,能在CPU上快速训练并看到明显效果。

以下我们将分步实现这个案例,代码中包含详细注释。

1. 环境准备与数据加载

首先,确保已安装PyTorch。然后,我们利用torchvision库加载MNIST数据集,它内置了下载和预处理功能。

import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms import matplotlib.pyplot as plt # 1. 定义数据预处理转换 # ToTensor()将PIL图像或NumPy数组转换为PyTorch张量,并自动归一化像素值到[0,1]区间 # Normalize()进行标准化,使用MNIST数据集的均值和标准差,使数据分布更稳定,利于训练 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # MNIST的均值和标准差 ]) # 2. 下载并加载训练集和测试集 train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform) # 3. 创建数据加载器 (DataLoader) # DataLoader负责批量获取数据、打乱顺序、使用多线程加速数据加载 batch_size = 64 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 4. 可视化一批数据,检查是否正确加载 def imshow(img): # 反标准化,以便正常显示 img = img * 0.3081 + 0.1307 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) # 调整维度顺序为(H, W, C) plt.show() # 获取一个批次的数据 dataiter = iter(train_loader) images, labels = next(dataiter) print(f"图像张量形状: {images.shape}") # 应为 [batch_size, 1, 28, 28] print(f"标签形状: {labels.shape}") # 应为 [batch_size] # 显示一个批次中的前8张图片及其标签 fig, axes = plt.subplots(1, 8, figsize=(12, 3)) for i in range(8): ax = axes[i] img = images[i].squeeze() * 0.3081 + 0.1307 # 反标准化并去掉通道维度 ax.imshow(img.numpy(), cmap='gray') ax.set_title(f'Label: {labels[i].item()}') ax.axis('off') plt.show()

2. 构建神经网络模型

我们将构建一个简单的全连接神经网络。MNIST图像是28x28像素的灰度图,拉平后是一个784维的向量。模型结构为:输入层(784) -> 隐藏层(128) -> ReLU激活函数 -> 输出层(10,对应10个数字类别)。

class SimpleNN(nn.Module): """ 一个简单的全连接神经网络。 继承自nn.Module,这是PyTorch中所有神经网络模块的基类。 """ def __init__(self, input_size=784, hidden_size=128, num_classes=10): super(SimpleNN, self).__init__() # 必须调用父类的初始化方法 # 定义网络层 self.fc1 = nn.Linear(input_size, hidden_size) # 第一个全连接层 self.relu = nn.ReLU() # 非线性激活函数 self.fc2 = nn.Linear(hidden_size, num_classes) # 第二个全连接层(输出层) # 注意:输出层后没有Softmax,因为CrossEntropyLoss内部已包含Softmax def forward(self, x): """ 定义前向传播过程。 x: 输入张量,形状为 [batch_size, 1, 28, 28] """ # 将图像拉平成一维向量: [batch_size, 1*28*28] x = x.view(-1, 28*28) # 通过各层 x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x # 输出形状: [batch_size, 10] # 实例化模型、损失函数和优化器 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"使用设备: {device}") model = SimpleNN().to(device) # 将模型参数移动到指定设备(GPU/CPU) criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,适用于多分类问题 optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器,自适应学习率

3. 训练模型

训练过程是深度学习的核心,即重复进行前向传播、计算损失、反向传播和参数更新。

num_epochs = 5 # 遍历整个训练集的次数 train_loss_history = [] train_acc_history = [] for epoch in range(num_epochs): model.train() # 设置为训练模式(影响某些层如Dropout、BatchNorm的行为) running_loss = 0.0 correct = 0 total = 0 for batch_idx, (images, labels) in enumerate(train_loader): # 1. 将数据移至设备 images, labels = images.to(device), labels.to(device) # 2. 前向传播:计算预测输出 outputs = model(images) # outputs shape: [64, 10] # 3. 计算损失 loss = criterion(outputs, labels) # 4. 反向传播与优化 optimizer.zero_grad() # 清空上一次迭代的梯度,防止累积 loss.backward() # 自动计算所有参数的梯度(自动微分) optimizer.step() # 根据梯度更新模型参数 # 5. 统计损失和准确率 running_loss += loss.item() _, predicted = outputs.max(1) # 获取预测类别(最大值的索引) total += labels.size(0) correct += predicted.eq(labels).sum().item() # 每处理100个batch打印一次信息 if (batch_idx + 1) % 100 == 0: print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}') # 计算本epoch的平均损失和准确率 epoch_loss = running_loss / len(train_loader) epoch_acc = 100. * correct / total train_loss_history.append(epoch_loss) train_acc_history.append(epoch_acc) print(f'Epoch [{epoch+1}/{num_epochs}] 训练完成 -> 平均损失: {epoch_loss:.4f}, 准确率: {epoch_acc:.2f}%') print('-' * 50) # 可视化训练过程 plt.figure(figsize=(12, 4)) plt.subplot(1, 2, 1) plt.plot(range(1, num_epochs+1), train_loss_history, marker='o') plt.xlabel('Epoch') plt.ylabel('Loss') plt.title('Training Loss') plt.grid(True) plt.subplot(1, 2, 2) plt.plot(range(1, num_epochs+1), train_acc_history, marker='o', color='orange') plt.xlabel('Epoch') plt.ylabel('Accuracy (%)') plt.title('Training Accuracy') plt.grid(True) plt.tight_layout() plt.show()

4. 评估模型

训练完成后,必须在未见过的测试集上评估模型性能,以检验其泛化能力。

def evaluate_model(model, data_loader, device): """ 评估模型在给定数据加载器上的准确率。 """ model.eval() # 设置为评估模式(关闭Dropout等) correct = 0 total = 0 # 在评估阶段,不需要计算梯度,以节省内存和计算资源 with torch.no_grad(): for images, labels in data_loader: images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() accuracy = 100. * correct / total return accuracy # 在测试集上评估 test_accuracy = evaluate_model(model, test_loader, device) print(f'模型在测试集上的准确率: {test_accuracy:.2f}%') # 可视化一些测试样本的预测结果 model.eval() dataiter = iter(test_loader) images, labels = next(dataiter) images, labels = images.to(device), labels.to(device) with torch.no_grad(): outputs = model(images[:8]) _, predicted = outputs.max(1) # 将图像移回CPU以便用matplotlib显示 images_cpu = images.cpu() fig, axes = plt.subplots(2, 4, figsize=(12, 6)) axes = axes.flatten() for i in range(8): ax = axes[i] img = images_cpu[i].squeeze() * 0.3081 + 0.1307 ax.imshow(img.numpy(), cmap='gray') ax.set_title(f'True: {labels[i].item()}, Pred: {predicted[i].item()}', color='green' if labels[i]==predicted[i] else 'red') ax.axis('off') plt.suptitle('测试样本预测结果 (绿色正确,红色错误)', fontsize=14) plt.tight_layout() plt.show()

5. 案例总结与价值延伸

通过这个案例,新手可以掌握以下核心价值点

步骤掌握的核心概念与技能价值体现
数据加载Dataset,DataLoader, 数据预处理(transforms)理解如何将原始数据组织成模型可用的批量张量,这是所有深度学习项目的第一步。
模型定义nn.Module,nn.Linear,nn.ReLU,forward方法学会像搭积木一样构建神经网络,理解前向传播的流程。
训练循环损失函数(criterion)、优化器(optimizer)、.backward().step().zero_grad()理解深度学习如何通过反向传播梯度下降来学习,这是最核心的机制。
模型评估model.eval(),with torch.no_grad(),计算准确率学会如何公正地衡量模型性能,避免过拟合的自我欺骗。
结果可视化训练曲线、预测样本展示直观理解模型的学习过程和最终表现,是调试和展示的关键。

下一步的探索方向(为新手提供持续学习的价值)

  1. 提高准确率:尝试增加网络层数或隐藏单元数,观察模型容量的影响。
  2. 使用卷积神经网络(CNN):将本例中的SimpleNN替换为CNN(如使用nn.Conv2dnn.MaxPool2d)。CNN更适合图像数据,通常能获得更高的准确率(可轻松达到99%以上),这是理解归纳偏置特征提取的重要一步。
  3. 调试技巧:如果准确率很低(例如始终在10%左右,即随机猜测水平),检查数据预处理、模型输出维度、损失函数是否匹配任务(如二分类 vs 多分类)。
  4. 部署尝试:使用torch.jit.trace将训练好的模型保存为TracedModule,体验模型序列化,这是将模型投入实际应用的第一步。

这个案例从零开始,完整地走完了一个深度学习项目的基本流程,代码简洁但功能完整,新手在成功运行并看到>90%的测试准确率后,能获得强烈的正反馈,并以此为基石,向更复杂的模型和任务迈进。


参考来源

  • 深度学习如何入门
  • 如何从零开始掌握Practical_DL:深度学习新手必备的8周入门教程
  • 视频教程-深度学习与PyTorch入门实战教程-深度学习
  • 深度学习框架DeepLearning4J(DL4J)的安装及配置
  • dl_tutorials:深入浅出深度学习教程
  • PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】笔记
http://www.jsqmd.com/news/806508/

相关文章:

  • Redis++ TLS/SSL安全连接终极指南:保护你的Redis数据传输安全 [特殊字符]
  • 无传感器BLDC电机启动优化与RL78/G1F控制方案
  • K8sGPT:AI驱动的Kubernetes智能诊断与根因分析实践指南
  • Canopy框架:快速构建本地RAG应用的AI开发利器
  • React Native Actions Sheet源码解析:深入理解其架构与实现原理
  • API测试终极指南:构建高效自动化测试套件的10个关键步骤
  • 半导体创业IPO之路:从技术到市场的四大鸿沟与实战指南
  • 终极Passport.js与TypeScript集成指南:打造类型安全的Node.js身份验证系统
  • NocoBase v1.9.0 重磅发布:10大新功能让低代码开发更强大
  • Smart-SSO分布式部署踩坑实录:从POM依赖改写到Nginx配置的那些‘坑’
  • 如何在 Shell 脚本中解析带空格的命令行参数?
  • Linux Idle 调度器的 on_rq 状态:Idle 任务的运行队列管理
  • GEO优化行业主流服务商核心技术与服务能力盘点
  • 【老王架构指南】2026年库存账实不符怎么破?基于实在Agent的非侵入式盘点自动化落地全攻略
  • LLPlayer:基于本地AI的智能语言学习视频播放器实战指南
  • 拓璞数控开启招股:拟募资17亿港元 5月20日上市 RBC高瓴博裕加持
  • 深度定制游戏模型系统:3DMigoto架构解析与性能优化方案
  • 低压柜定制厂家,高压柜哪个牌子好,上海彬长电力设备、并网柜、箱变实力厂家,一文带你掌握 - 栗子测评
  • 基于Docker的AI智能体沙箱环境构建:open-harness项目实战指南
  • 中国移动2012年战略抉择:放弃iPhone补贴,押注TD-LTE自主标准
  • LLM Agent论文清单高效使用指南:从入门到精通的系统化路径
  • 基于多智能体系统的AI量化交易架构设计与实战解析
  • 从零构建生成式AI项目:RAG、智能体与微调实战指南
  • 从EE Times圣诞标题竞赛看技术社区创意运营与社群激活
  • 终极指南:如何在Android设备上运行Windows应用程序
  • 驭势科技开启招股:拟募资8.7亿港元 5月20日上市 雄安自动驾驶是基石投资者
  • Linux Idle 调度器的 arch_cpu_idle:体系结构相关的 Idle 实现
  • GraphMemory-IDE:专为图记忆应用设计的实时可视化开发环境
  • 从零构建专属AI桌面伙伴:my-neuro开源项目全解析与实战指南
  • Cursor编辑器历史链接管理工具:提升代码导航效率的智能解决方案