当前位置: 首页 > news >正文

PyTorch预训练模型加载实战:从.pth文件到迁移学习避坑指南

1. 从零开始加载.pth文件的完整流程

第一次用PyTorch加载预训练模型时,我盯着那个.pth文件发呆了半小时——明明按照官方文档写的代码,却总是报各种奇怪的错误。后来才发现,从下载模型到加载权重,每个环节都藏着不少坑。下面我就用SqueezeNet为例,带你完整走一遍这个流程。

先说说最常见的网络下载问题。当你运行model = models.squeezenet1_1(pretrained=True)时,程序会尝试从PyTorch服务器下载模型文件。但在国内环境下,十次有九次会碰到这样的报错:

requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, '由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败。', None, 10060, None))

这时候别急着翻墙(注意:所有操作都应在合法合规前提下进行),我有更简单的解决方案。仔细观察报错信息,会发现类似这样的下载链接:

Downloading: "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth"

把这个链接复制到浏览器,如果打不开,试试去掉https://前缀,直接访问download.pytorch.org/models/squeezenet1_1-f364aa15.pth。我实测这个方法在移动宽带和电信网络下都能成功下载。

下载完成后,你可能会遇到SSL证书验证问题。这时候需要在代码开头加上:

import ssl ssl._create_default_https_context = ssl._create_unverified_context

不过要提醒的是,这只是一个临时解决方案,在生产环境中应该配置正确的证书验证方式。

2. 模型加载的两种姿势与常见陷阱

拿到.pth文件后,新手最容易犯的错误就是直接torch.load()整个文件。用这个命令加载后,一定要先用print看看内容结构:

import torch pthfile = 'squeezenet1_1-f364aa15.pth' net = torch.load(pthfile) print(type(net)) # 输出会是OrderedDict或nn.Module

如果是OrderedDict,说明只保存了权重参数;如果是nn.Module,则是完整模型结构+参数。对于官方预训练模型,通常都是前者。这时候正确的加载姿势是:

import torchvision.models as models # 先创建空模型结构 model = models.squeezenet1_1(pretrained=False) # 然后加载权重参数 model.load_state_dict(torch.load(pthfile))

这里有个隐藏的坑:如果模型结构不匹配,会报Missing key(s) in state_dict错误。我就曾经因为用了squeezenet1_0的结构加载1_1的权重,调试了半天找不到原因。

3. 迁移学习改造实战指南

现在来到最关键的迁移学习环节。假设我们要用SqueezeNet做10分类任务,通常的操作流程是:

  1. 冻结所有底层参数
  2. 替换最后的分类层
  3. 只训练新添加的层

代码看起来很简单:

import torch.nn as nn # 加载预训练模型 model = models.squeezenet1_1(pretrained=True) # 冻结所有参数 for param in model.parameters(): param.requires_grad = False # 修改分类器 model.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1,1))

但运行后你可能会遇到一个诡异的错误:

RuntimeError: shape '[25, 1000]' is invalid for input of size 50

这是因为SqueezeNet内部还有个num_classes属性没改!这个坑官方文档可没提醒,是我踩了三次才发现的。完整解决方案是:

model.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1,1)) model.num_classes = 10 # 这个千万别漏!

4. 参数冻结与解冻的高级技巧

在实际项目中,我们往往不需要冻结所有层。比如对于SqueezeNet,我会选择:

  • 完全冻结前3个fire模块(特征提取层)
  • 部分解冻最后2个fire模块(特征融合层)
  • 完全解冻分类器层

具体实现代码:

# 按名称选择性冻结 for name, param in model.named_parameters(): if 'features.0' in name or 'features.3' in name or 'features.6' in name: param.requires_grad = False elif 'features.9' in name or 'features.12' in name: param.requires_grad = True # 部分解冻 else: param.requires_grad = True # 完全解冻 # 查看哪些层需要更新 params_to_update = [] for name, param in model.named_parameters(): if param.requires_grad: params_to_update.append(param) print("可训练参数:", name)

这种分层冻结策略在我的花卉分类项目中,使验证准确率提升了12%。关键是要理解网络不同层的作用——前面的卷积层提取基础特征,后面的层组合高级特征。

5. 模型保存与加载的最佳实践

训练好的模型需要妥善保存。我推荐使用以下两种方式:

  1. 保存完整模型(结构+参数):
torch.save(model, 'full_model.pth')

加载时直接model = torch.load('full_model.pth')

  1. 只保存参数(推荐):
torch.save(model.state_dict(), 'params_only.pth')

加载时需要先创建结构:

model = models.squeezenet1_1(pretrained=False) model.load_state_dict(torch.load('params_only.pth'))

特别注意:如果用第一种方式保存,加载时可能因为类定义变化导致报错。有次我升级PyTorch版本后,之前保存的模型就加载失败了。所以生产环境强烈推荐第二种方式。

6. 跨设备加载的兼容性问题

当你在GPU训练后要在CPU部署,或者反过来,会遇到经典的RuntimeError: Attempting to deserialize object on CUDA device but torch.cuda.is_available() is False。解决方案是:

# GPU保存 → CPU加载 model.load_state_dict(torch.load('gpu_model.pth', map_location=torch.device('cpu'))) # CPU保存 → GPU加载 model.load_state_dict(torch.load('cpu_model.pth', map_location='cuda:0')) model = model.cuda()

还有个更智能的写法,适合不确定部署环境的情况:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.load_state_dict(torch.load('model.pth', map_location=device))

7. 实战中的性能优化技巧

最后分享几个提升加载效率的技巧:

  1. 使用torch.save_use_new_zipfile_serialization参数可以减小文件体积:
torch.save(model.state_dict(), 'compressed.pth', _use_new_zipfile_serialization=False)
  1. 对于大型模型,可以分块加载:
from collections import OrderedDict state_dict = torch.load('huge_model.pth') new_state_dict = OrderedDict() for k, v in state_dict.items(): if k.startswith('features.0'): # 只加载特定部分 new_state_dict[k] = v model.load_state_dict(new_state_dict, strict=False)
  1. 使用torch.jit.trace可以加速模型加载:
example_input = torch.rand(1, 3, 224, 224) traced_model = torch.jit.trace(model, example_input) torch.jit.save(traced_model, 'traced_model.pt')
http://www.jsqmd.com/news/659046/

相关文章:

  • 从理论到仿真:如何用Simulink的PMSM模块验证你的电机控制算法?
  • 深入解析TMS320F2803x DSP的ePWM模块:从基础配置到高级应用
  • 避坑指南:单片机串口调试时,TI和RI中断标志位那些容易踩的坑
  • 外国人为何涌向这家江南医美诊所?丽贝瑞 REBERRY 的三大核心竞争力
  • 多轮对话长上下文-向量检索和混合召回示例
  • 从电路分析到控制系统:拉普拉斯变换的工程应用避坑指南
  • Floccus实现跨浏览器书签同步
  • 从Velodyne到Livox:不同品牌激光雷达的坐标系‘方言’与ROS下的统一处理实践
  • news-please:革命性新闻爬虫工具,一站式解决新闻信息提取难题
  • 如何利用MySQLd Exporter构建企业级MySQL监控系统
  • 释放STM32的矩阵算力:ARM CMSIS-DSP库实战指南
  • SpringBoot+MyBatis实战:构建企业级CRM客户管理系统的核心模块与架构设计
  • 你的 Vue 3 defineAsyncComponent(),VuReact 会编译成什么样的 React?
  • 用手机控制电脑桌面:Lan Mouse让你的跨设备操作变得如此简单
  • MATLAB雷达仿真避坑指南:从LFM信号生成到脉冲压缩的完整流程(附代码)
  • CefFlashBrowser终极指南:如何在现代电脑上完美运行经典Flash游戏和内容
  • 鸿蒙flutter测试文章3
  • 方向向量在游戏开发中如何应用,高数下空间几何到底有什么用处
  • huatuo兼容性报告:如何无缝集成第三方库和框架
  • 10个TinyEditor实用技巧:从基础使用到高级定制
  • Go语言如何写TCP服务器_Go语言TCP Server教程【全面】
  • 终极指南:Gamescope三大后端架构解析 - DRM、SDL与Wayland实现原理深度剖析
  • Three.js动画效果
  • 软件身份管理中的用户生命周期
  • 沙特阿拉伯王储主持的沙特公共投资基金(PIF)董事会通过并公布PIF 2026-2030年战略
  • 2026年比较好的汽车叶轮注塑模具厂家哪家好 - 品牌宣传支持者
  • 【Linux】Linux环境基础开发工具使用
  • 【万字文档+PPT+源码】基于springboot+vue在线投票系统-计算机专业项目设计分享
  • AutoSpotting终极指南:如何在AWS上节省90%EC2成本
  • 实锤了!Hermes被爆抄袭中国团队代码