当前位置: 首页 > news >正文

Python AI 与深度学习 - D2.MNIST 手写数字识别

# 导入所有需要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# -------------------------- 1. 配置超参数和设备 --------------------------
batch_size = 64  # 每次加载64个样本
epochs = 5       # 整个数据集训练5遍
lr = 0.001       # 学习率
# 设备配置:优先CPU(入门)
device = torch.device("cpu")# -------------------------- 2. 数据预处理和加载MNIST数据集 --------------------------
# 数据预处理:将图像转为Tensor+归一化(0-1之间)
transform = transforms.Compose([transforms.ToTensor(),  # 转为Tensor,形状从(28,28)→(1,28,28),值从0-255→0-1transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,固定值
])
# 加载训练集和测试集(自动下载,无需手动找数据)
train_dataset = datasets.MNIST(root='./data',  # 数据保存路径train=True,     # 训练集download=True,  # 自动下载transform=transform
)
test_dataset = datasets.MNIST(root='./data',train=False,    # 测试集download=True,transform=transform
)
# 用DataLoader批量加载数据,打乱顺序
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# -------------------------- 3. 定义简单的卷积神经网络(CNN) --------------------------
class MNIST_CNN(nn.Module):def __init__(self):super(MNIST_CNN, self).__init__()# 卷积层:输入1通道(灰度图),输出16通道,卷积核3x3,步长1,填充1self.conv1 = nn.Conv2d(1, 16, 3, 1, 1)self.relu = nn.ReLU()  # 激活函数self.pool = nn.MaxPool2d(2, 2)  # 池化层:2x2,步长2,缩小图像尺寸# 卷积层:输入16通道,输出32通道self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)# 全连接层:将特征展平后,映射到10个输出(0-9数字)self.fc1 = nn.Linear(32 * 7 * 7, 10)  # 28→池化2次→7,32*7*7是展平后的特征数def forward(self, x):# 前向传播:卷积→激活→池化→卷积→激活→池化→展平→全连接x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 32 * 7 * 7)  # 展平:(batch_size, 32,7,7)→(batch_size, 32*7*7)x = self.fc1(x)return x# 实例化模型,移到指定设备
model = MNIST_CNN().to(device)# -------------------------- 4. 定义损失函数和优化器 --------------------------
criterion = nn.CrossEntropyLoss()  # 交叉熵损失,适合分类任务
optimizer = optim.Adam(model.parameters(), lr=lr)  # Adam优化器,比SGD更快收敛# -------------------------- 5. 模型训练 --------------------------
print("开始训练MNIST模型...")
model.train()  # 模型进入训练模式
for epoch in range(epochs):running_loss = 0.0for i, (images, labels) in enumerate(train_loader):# 将数据移到指定设备images, labels = images.to(device), labels.to(device)# 1. 清空梯度(必须步骤,否则梯度累加)
        optimizer.zero_grad()# 2. 前向传播:输入图像,得到预测结果outputs = model(images)# 3. 计算损失:预测结果与真实标签的差距loss = criterion(outputs, labels)# 4. 反向传播:计算梯度
        loss.backward()# 5. 优化器更新参数
        optimizer.step()# 累计损失running_loss += loss.item()# 打印每个epoch的平均损失avg_loss = running_loss / len(train_loader)print(f"第{epoch+1}/{epochs}轮训练,平均损失:{avg_loss:.4f}")print("训练完成!")# -------------------------- 6. 模型测试(评估准确率) --------------------------
print("开始测试模型...")
model.eval()  # 模型进入测试模式,关闭Dropout/BN等
correct = 0
total = 0
# 测试时不需要计算梯度,节省内存
with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)# 取预测概率最大的类别作为预测结果_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 打印测试准确率
accuracy = 100 * correct / total
print(f"MNIST测试集准确率:{accuracy:.2f}%")  # 正常训练后准确率≥98%# -------------------------- 7. 可视化预测结果(可选) --------------------------
# 取测试集中的前5张图像,展示预测结果
images, labels = next(iter(test_loader))
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)# 绘制图像
plt.figure(figsize=(10, 5))
for i in range(5):plt.subplot(1, 5, i+1)plt.imshow(images[i].squeeze().numpy(), cmap='gray')  # 去掉通道维度,显示灰度图plt.title(f"真实:{labels[i].item()}\n预测:{predicted[i].item()}")plt.axis('off')
plt.show()# -------------------------- 8. 保存模型(后续可直接加载使用) --------------------------
torch.save(model.state_dict(), './mnist_cnn_model.pth')
print("模型已保存为:mnist_cnn_model.pth")# -------------------------- 加载模型(后续使用) --------------------------
# new_model = MNIST_CNN().to(device)
# new_model.load_state_dict(torch.load('./mnist_cnn_model.pth'))
# new_model.eval()  # 加载后必须进入测试模式

 

运行完成后,会得到 3 个关键有效输出,代表项目成功:

当前使用设备:cuda:0
开始训练MNIST模型...
第1/5轮训练,平均损失:0.2xx
第2/5轮训练,平均损失:0.0xx
...
训练完成!
开始测试模型...
MNIST测试集准确率:98.xx%
模型已保存为:mnist_cnn_model.pth

 

http://www.jsqmd.com/news/341464/

相关文章:

  • 简单理解:CAN 收发器 TJA1050 如何将来自微控制器的单端 TTL/CMOS 逻辑信号转换为 CAN 总线所需的差分信号。
  • 分布式淘客系统的配置中心设计:Nacos在多环境配置管理的应用
  • 京东e卡回收平台:快速兑现礼品卡竟如此简单! - 团团收购物卡回收
  • 2026全自动咖啡机选哪个牌子好 靠谱高性价比口碑品牌推荐 - 品牌2025
  • 2026年智能语音机器人厂商优选指南:上门服务、免费演示及合作流程 - 品牌2025
  • 2026年抗风卷帘门公司权威推荐:镂空卷帘门/防火卷帘门/防火门/pvc快速门/别墅车库门/堆积门/工业门/彩钢卷帘门/选择指南 - 优质品牌商家
  • 返利机器人的商品数据同步方案:API拉取与增量更新的技术实现
  • 2026家用睡眠仪推荐TOP榜,每款标注推荐指数 - 速递信息
  • 对接印度股票市场内容 (India api) 实时k线图表
  • 内网环境下,如何使用js处理大文件的目录结构上传?
  • Genie 3 震撼发布:AI 绘图或成“画笔”,但游戏灵魂仍由人类执笔!
  • 2026年优质客服系统厂商推荐,覆盖高性价比、智能语音与一体化服务 - 品牌2025
  • 淘客返利系统的CI/CD流水线搭建:Docker镜像构建与K8s部署实践
  • 2026年2月全自动贴袋机/免烫贴袋机/全自动贴兜机/免烫贴兜机/全自动开袋机/全自动开兜机行业TOP5服务商全景评估报告 - 2026年企业推荐榜
  • PAM-COMPOSITE 复合材料仿真数据导出操作手册(适配可视化工具专用)
  • 2026年靠谱的APP开发公司有哪些?基于多维度数据的客观盘点
  • 2026智能咖啡机推荐 哪家值得信赖口碑好性价比高服务好质量优 - 品牌2025
  • 淘宝返利软件的可观测性架构:Prometheus与Grafana监控体系搭建
  • 2026年优质客服系统厂商推荐,覆盖在线试用、智能应答与全渠道售后 - 品牌2025
  • 西门子PLC设备锁机程序探秘:S7 - 200cn与S7 - 200 smart的独特应用
  • 国内外市场占有率高、质量好且售后服务好的介电常数测定仪厂家推荐 - 品牌推荐大师1
  • 淘宝客返利系统的用户数据安全设计:脱敏存储与接口访问控制
  • 一天一个Python库:pygments - 强大的代码高亮和格式化工具
  • 避坑指南|2026年2月敏感肌护肤品终极测评:这些误区别踩,选对比选贵重要 - 速递信息
  • 淘客系统的佣金资金流处理:数据追溯与账户交易的安全机制
  • # 缓存与数据库的协调策略【缓存更新时机】
  • 2026医用级硅胶生产厂家推荐榜:三大标杆企业助力医疗设备精准化升级 - 速递信息
  • Opencv 学习笔记:提取轮廓中心点坐标(矩计算法)
  • 美通卡回收的实操图文指南 - 京回收小程序
  • 2026厂房洁净室工程怎么选?5家行业标杆企业值得关注 - 品牌2025