当前位置: 首页 > news >正文

P66实训题

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize
import matplotlib.pyplot as plt
import numpy as np

1. 数据加载与预处理

transform = Compose([
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

类别标签

classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')

查看数据集样本(可选)

def imshow(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()

dataiter = iter(train_loader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images[:4]))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

2. 构建网络

class Net(nn.Module):
def init(self):
super(Net, self).init()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
self.fc1 = nn.Linear(128 * 4 * 4, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout = nn.Dropout(0.3)

def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.dropout(x)x = self.pool(torch.relu(self.conv2(x)))x = self.dropout(x)x = self.pool(torch.relu(self.conv3(x)))x = self.dropout(x)x = x.view(-1, 128 * 4 * 4)x = torch.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

net = Net()

3. 定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

4. 训练网络

epochs = 30
train_losses = []
train_accs = []
test_losses = []
test_accs = []

for epoch in range(epochs):
running_loss = 0.0
correct = 0
total = 0
net.train()
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
train_loss = running_loss / len(train_loader)
train_acc = 100. * correct / total
train_losses.append(train_loss)
train_accs.append(train_acc)

# 测试
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:outputs = net(inputs)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()
test_loss = test_loss / len(test_loader)
test_acc = 100. * correct / total
test_losses.append(test_loss)
test_accs.append(test_acc)print(f'Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}% | Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.2f}%')

print('Finished Training')

绘制训练曲线

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.legend()
plt.title('Loss')

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(test_accs, label='Test Acc')
plt.legend()
plt.title('Accuracy')
plt.show()

5. 测试模型精度

net.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
outputs = net(inputs)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()

print(f'测试集精度: {100. * correct / total:.2f}%')

查看各类别预测精度(可选)

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for inputs, labels in test_loader:
outputs = net(inputs)
_, predicted = outputs.max(1)
c = (predicted == labels).squeeze()
for i in range(len(labels)):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1

for i in range(10):
print(f'类别 {classes[i]} 的精度: {100. * class_correct[i] / class_total[i]:.2f}%')

http://www.jsqmd.com/news/14311/

相关文章:

  • 非主流网站程序IndexNow添加方法
  • 卷积神经网络视频读书报告
  • C 语言 - 内存操作函数以及字符串操作函数解析
  • 以*this返回局部对象的两种情况
  • 2025.10.15
  • 2025秋_12
  • nginx-1.16.1-2.p01.ky10.sw_64.rpm 安装教程(详细步骤,适用于Kylin V10/申威SW64架构)
  • 第七章:C控制语句:分支和跳转
  • 感知节点@5@ ESP32+arduino+ 第三个程序FreeRTOS 上 LED灯显示 和 串口打印ASCII表
  • 【Azure App Service】App Service是否支持PHP的版本选择呢?
  • OAuth/OpenID Connect 渗透测试完全指南
  • Problem K. 置换环(The ICPC online 2025)思路解析 - tsunchi
  • Go 语言和 Tesseract OCR 识别英文数字验证码
  • Markdown转换为Word:Pandoc模板使用指南 - 实践
  • 2025年10月小程序开发公司最新推荐排行榜,小程序定制开发,电商小程序开发,预订服务小程序开发,活动报名小程序开发!
  • 数据结构-循环队列
  • C语言学习——键盘录入
  • 2025年10月软件开发公司最新推荐,软件定制开发,crm系统定制软件开发,管理系统软件开发,物联网软件开发公司推荐!
  • C语言学习——运算符的学习
  • 第十五篇
  • 数据结构-顺序栈
  • 实用指南:NXP - 用MCUXpresso IDE v25.6.136的工具链编译Smoothieware固件工程
  • Erlang 的英文数字验证码识别系统设计与实现
  • 使用Django从零开始构建一个个人博客系统 - 实践
  • 2025年磨床厂家TOP企业品牌推荐排行榜,平面磨床,外圆磨床,数控平面磨床,数控外圆磨床,7163平面磨床推荐这十家公司!
  • cifar10
  • [LangChain] 02. 模型接口
  • 摄像头调试
  • 软件工程作业-报告1 - 实践
  • WebGL学习及项目实战(第02期:绘制一个点)