当前位置: 首页 > news >正文

P66实训2

运行代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Normalize
import time

1. 数据加载(简化预处理)

transform = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

2. 简化模型结构

class SimpleModel(nn.Module):
def init(self):
super().init()
self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
self.flat = nn.Flatten()
self.fc = nn.Linear(16 * 16 * 16, 10) # 32/2/2=8?不对,32经过两次池化(22)后是8×8?哦,我之前算错了,重新来:32→池化后16→再池化后8,所以168*8=1024。抱歉之前的错误,这里修正。
self.fc = nn.Linear(16 * 8 * 8, 10)

def forward(self, x):x = self.relu(self.conv(x))x = self.pool(x)x = self.relu(self.conv(x))  # 再加一层卷积x = self.pool(x)x = self.flat(x)x = self.fc(x)return x

model = SimpleModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

3. 损失函数与优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

4. 简化训练(减少轮数)

epochs = 10
start_time = time.time()

for epoch in range(epochs):
model.train()
running_loss = 0.0
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1}, 训练损失: {running_loss/len(train_loader):.3f}")

print(f"训练耗时: {time.time()-start_time:.2f} 秒")

5. 简单评估

model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

print(f"测试准确率: {correct/total:.3f}")
测试结果图片
1760533418483

http://www.jsqmd.com/news/14325/

相关文章:

  • 《程序员的修炼之道:从小工到专家》阅读笔记
  • 关于Pytorch深度学习神经网络的读书报告
  • const int *p和int *const p快速区分
  • 二分图、拓扑与欧拉
  • Zhengrui #3346. DINO
  • Pytorch深度学习训练
  • P11894 「LAOI-9」Update
  • win10软实时设置 - 教程
  • 实用指南:Hunyuan3D-Omni:可控3D资产生成的统一框架
  • ZR 2025 NOIP 二十连测 Day 3
  • 实用指南:2025年9月个人工作生活总结
  • P14223 [ICPC 2024 Kunming I] 乐观向上
  • 别再用均值填充了!MICE算法教你正确处理缺失数据
  • P66实训题
  • 非主流网站程序IndexNow添加方法
  • 卷积神经网络视频读书报告
  • C 语言 - 内存操作函数以及字符串操作函数解析
  • 以*this返回局部对象的两种情况
  • 2025.10.15
  • 2025秋_12
  • nginx-1.16.1-2.p01.ky10.sw_64.rpm 安装教程(详细步骤,适用于Kylin V10/申威SW64架构)
  • 第七章:C控制语句:分支和跳转
  • 感知节点@5@ ESP32+arduino+ 第三个程序FreeRTOS 上 LED灯显示 和 串口打印ASCII表
  • 【Azure App Service】App Service是否支持PHP的版本选择呢?
  • OAuth/OpenID Connect 渗透测试完全指南
  • Problem K. 置换环(The ICPC online 2025)思路解析 - tsunchi
  • Go 语言和 Tesseract OCR 识别英文数字验证码
  • Markdown转换为Word:Pandoc模板使用指南 - 实践
  • 2025年10月小程序开发公司最新推荐排行榜,小程序定制开发,电商小程序开发,预订服务小程序开发,活动报名小程序开发!
  • 数据结构-循环队列