当前位置: 首页 > news >正文

Day40 预训练

预训练 = 用别人在超大数据集(ImageNet)上训练好的模型权重 → 拿来给我们用,少训练、效果好、速度快

import torch import torch.nn as nn import torch.optim as optim from torchvision import models, datasets, transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt # ===================== 1. 配置 ===================== device = torch.device("cuda" if torch.cuda.is_available() else "cpu") batch_size = 64 lr = 0.001 epochs = 5 num_classes = 10 # ===================== 2. 数据预处理(必须匹配预训练模型要求) ===================== transform = transforms.Compose([ transforms.Resize((224, 224)), # ResNet 标准输入 224 transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet 均值 std=[0.229, 0.224, 0.225]) ]) # CIFAR10 数据集 train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # ===================== 3. 加载预训练模型(核心!) ===================== # 加载 ResNet18 预训练权重 model = models.resnet18(pretrained=True) # ===================== 4. 冻结卷积层(只训练最后的全连接层) ===================== for param in model.parameters(): param.requires_grad = False # 冻结所有层 # 替换最后一层全连接层(1000类 → 10类) in_features = model.fc.in_features model.fc = nn.Linear(in_features, num_classes) model = model.to(device) # ===================== 5. 损失函数 & 优化器 ===================== criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.fc.parameters(), lr=lr) # 只训练最后一层 # ===================== 6. 训练函数 ===================== def train(model, loader, criterion, optimizer, device): model.train() total_loss = 0 correct = 0 for data, target in loader: data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() total_loss += loss.item() pred = output.argmax(dim=1) correct += pred.eq(target).sum().item() avg_loss = total_loss / len(loader) acc = 100.0 * correct / len(loader.dataset) return avg_loss, acc # ===================== 7. 测试函数 ===================== def test(model, loader, criterion, device): model.eval() total_loss = 0 correct = 0 with torch.no_grad(): for data, target in loader: data, target = data.to(device), target.to(device) output = model(data) total_loss += criterion(output, target).item() pred = output.argmax(dim=1) correct += pred.eq(target).sum().item() avg_loss = total_loss / len(loader) acc = 100.0 * correct / len(loader.dataset) return avg_loss, acc # ===================== 8. 开始训练 ===================== print("===== 开始训练(仅训练最后一层)=====") for epoch in range(1, epochs+1): train_loss, train_acc = train(model, train_loader, criterion, optimizer, device) test_loss, test_acc = test(model, test_loader, criterion, device) print(f"Epoch {epoch:2d} | " f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | " f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%") # ===================== 9. 解冻全部层,进行微调(进阶) ===================== print("\n===== 解冻全部层,进行精细微调 =====") for param in model.parameters(): param.requires_grad = True # 解冻所有层 optimizer = optim.Adam(model.parameters(), lr=1e-5) # 小学习率 for epoch in range(1, 3): train_loss, train_acc = train(model, train_loader, criterion, optimizer, device) test_loss, test_acc = test(model, test_loader, criterion, device) print(f"Epoch {epoch:2d} | " f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | " f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

@浙大疏锦行

http://www.jsqmd.com/news/440705/

相关文章:

  • 做豆包广告,联系哪家公司比较靠谱 - 品牌2026
  • 2026年热门的负氧离子床垫公司推荐:佛山负氧离子床垫厂家实力哪家强 - 行业平台推荐
  • 豆包推广广告可以投放吗?哪家公司提供相关服务? - 品牌2026
  • 2026年优质的木工机械螺杆空压机公司推荐:激光切割螺杆空压机/橡胶机械螺杆空压机/汽车配件螺杆空压机实力工厂怎么选 - 行业平台推荐
  • 2026年评价高的鲜面条生产线公司推荐:大型面条生产线/商用鲜面条生产线专业制造厂家推荐 - 行业平台推荐
  • 2026年优秀的烤漆龙骨品牌推荐:烤漆龙骨品牌厂家哪家靠谱 - 行业平台推荐
  • 稳定报告基因细胞系(Stable Reporter Cell Line)是什么?从 HEK293/CHO 到信号通路读出的系统性理解
  • 2026年知名的电加热农用榨油机公司推荐:一体式农用榨油机/气压组合农用榨油机可靠供应商推荐 - 行业平台推荐
  • Git Git Hooks 自定义钩子
  • 【亲测免费】 如何使用QtCSV库进行CSV文件读写
  • Git Git LFS 使用
  • # 发散创新:用Python实现神经渲染中的光照估计与材质重建 在计算机图形学与深度学习
  • Git Git Notes 注释
  • 动态规划 | part12
  • 2026年比较好的集束电缆厂家推荐:铝合金电缆公司口碑哪家靠谱 - 行业平台推荐
  • Git Git Prune 清理无效引用
  • 告别高额订阅费!ONLYOFFICE——企业协作办公的明智之选
  • 代码随想录算法训练营第二天 | 长度最小的子数组、螺旋矩阵Ⅱ、区间和、
  • 2026年质量好的全钢制公寓床公司推荐:员工宿舍公寓床高口碑品牌推荐 - 行业平台推荐
  • 2026年优秀的双层宿舍铁床工厂推荐:宿舍铁床款式厂家选择指南 - 行业平台推荐
  • day1寻找除数
  • 2026年口碑好的模压TPE颗粒工厂推荐:吸塑脚垫TPE颗粒/TPE汽车脚垫颗粒精选厂家推荐 - 行业平台推荐
  • 【大数据毕设全套源码+文档】基于django+深度学习的经典名著推荐系统设计与实现(丰富项目+远程调试+讲解+定制)
  • 2026年可靠的橡胶辊品牌推荐:钢辊橡胶辊/烫金轮橡胶辊实力工厂怎么选 - 行业平台推荐
  • 2026年比较好的PC板温室大棚品牌推荐:锯齿温室大棚/养殖温室大棚厂家实力与用户口碑参考 - 行业平台推荐
  • 2026年质量好的透气三明治网布厂家推荐:鞋材三明治网布/涤纶三明治网布实力厂家如何选 - 行业平台推荐
  • 2026年可靠的无马弗网带炉厂家推荐:等温正火式网带炉优质供应商推荐 - 行业平台推荐
  • Chartbrew:一个开源的数据可视化平台 - 指南
  • 麒麟系统安装mysql8
  • Godot游戏练习01-第3节-多人场景创建