PyTorch模型保存与加载:从state_dict到完整模型的实战解析
1. PyTorch模型保存的两种核心方式
第一次接触PyTorch模型保存时,很多人都会困惑:为什么有时候保存的模型文件可以直接使用,有时候却要先初始化模型结构?这其实涉及到PyTorch模型持久化的两种核心策略。我在实际项目中踩过不少坑,今天就把这些经验分享给大家。
最常用的两种保存方式分别是:
- 保存整个模型(包含结构和参数)
- 仅保存state_dict(只有参数)
先看一个简单的例子。假设我们有一个训练好的CNN模型,想把它保存下来:
import torch import torch.nn as nn class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3) self.fc = nn.Linear(16*28*28, 10) def forward(self, x): x = self.conv1(x) x = x.view(-1, 16*28*28) return self.fc(x) model = SimpleCNN()1.1 保存完整模型
这是最直接的方式,一行代码搞定:
torch.save(model, 'full_model.pth')这种方式会把模型的结构定义和参数值一起打包保存。加载时同样简单:
loaded_model = torch.load('full_model.pth')看起来很方便对吧?但我在实际项目中发现几个问题:
- 模型文件较大,因为包含了结构定义
- 当模型类定义发生变化时(比如修改了类名或路径),加载会失败
- 无法选择性加载部分参数
1.2 仅保存state_dict
更推荐的做法是保存state_dict:
torch.save(model.state_dict(), 'state_dict_only.pth')state_dict是PyTorch内部用来存储模型参数的字典对象,只包含参数值,不包含模型结构。加载时需要先初始化模型结构:
new_model = SimpleCNN() # 必须先创建相同结构的模型 new_model.load_state_dict(torch.load('state_dict_only.pth'))这种方式虽然多了一步,但灵活性更高。我在迁移学习场景中经常使用这种方式,可以只加载部分匹配的参数。
2. state_dict的深入解析
state_dict是理解PyTorch模型保存与加载的关键。刚开始接触时,我对这个概念也是一知半解,直到有一次调试模型加载失败的问题,才真正搞明白它的工作机制。
2.1 state_dict到底是什么?
state_dict本质上是一个Python字典,它将模型中的每个可学习参数(如权重和偏置)映射到对应的张量。举个例子,对于我们之前的SimpleCNN模型:
print(model.state_dict().keys()) # 输出:odict_keys(['conv1.weight', 'conv1.bias', 'fc.weight', 'fc.bias'])可以看到,state_dict的key是各层的名称加上参数类型(weight或bias),value就是对应的参数张量。
2.2 state_dict的高级用法
除了模型参数,优化器的state_dict也经常需要保存:
optimizer = torch.optim.Adam(model.parameters()) torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), }, 'checkpoint.pth')这样在恢复训练时,可以同时加载模型和优化器状态:
checkpoint = torch.load('checkpoint.pth') model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict'])我在训练大型模型时,经常使用这种checkpoint机制,可以随时中断和恢复训练过程。
3. 模型加载的常见问题与解决方案
在实际项目中,模型加载失败的情况很常见。下面分享几个我遇到过的典型问题及其解决方法。
3.1 模型结构不匹配
这是最常见的问题之一。当你尝试加载state_dict时,如果当前模型结构与保存时的结构不一致,就会报错:
# 假设我们修改了模型结构 class ModifiedCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 32, 3) # 输出通道从16改为32 self.fc = nn.Linear(32*28*28, 10) # 相应调整 def forward(self, x): x = self.conv1(x) x = x.view(-1, 32*28*28) return self.fc(x) new_model = ModifiedCNN() new_model.load_state_dict(torch.load('state_dict_only.pth')) # 会报错解决方法有两种:
- 严格保持模型结构不变
- 选择性加载匹配的参数:
pretrained_dict = torch.load('state_dict_only.pth') model_dict = new_model.state_dict() # 筛选出匹配的参数 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) new_model.load_state_dict(model_dict)3.2 设备不匹配问题
当模型在一个设备(如GPU)上训练,在另一个设备(如CPU)上加载时,可能会遇到设备不匹配的问题。我的经验是:
# 保存时指定map_location loaded_model = torch.load('model.pth', map_location=torch.device('cpu')) # 或者在加载state_dict后手动转换设备 model.load_state_dict(torch.load('state_dict.pth', map_location='cpu'))4. 实际应用场景与最佳实践
根据不同的应用场景,选择合适的模型保存和加载策略非常重要。下面分享几个典型场景下的实践经验。
4.1 模型部署场景
在模型部署时,我通常推荐保存完整模型:
torch.save(model, 'deployment_model.pth')这样部署时只需要一个文件,加载简单。但要注意:
- 确保部署环境的PyTorch版本与训练环境一致
- 模型类定义必须可访问(要么在同一个文件,要么正确导入)
4.2 迁移学习场景
做迁移学习时,state_dict方式更灵活:
# 保存预训练模型 torch.save(pretrained_model.state_dict(), 'pretrained.pth') # 在新模型上加载部分参数 new_model = NewModel() pretrained_dict = torch.load('pretrained.pth') model_dict = new_model.state_dict() # 只加载名称匹配的参数 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()} model_dict.update(pretrained_dict) new_model.load_state_dict(model_dict)4.3 多GPU训练场景
使用DataParallel或DistributedDataParallel时,模型会被包装,这时保存state_dict需要注意:
# 保存多GPU模型 model = nn.DataParallel(model) torch.save(model.module.state_dict(), 'multigpu_model.pth') # 注意使用.module # 加载时 single_model = SimpleCNN() single_model.load_state_dict(torch.load('multigpu_model.pth'))5. 性能优化与安全考虑
模型保存和加载不仅仅是功能实现,还需要考虑性能和安全性问题。这里分享一些实战经验。
5.1 文件大小优化
大型模型的文件可能非常大,可以考虑压缩保存:
# 使用zip格式压缩 torch.save(model.state_dict(), 'model_compressed.pth', _use_new_zipfile_serialization=True)我在处理ResNet等大型模型时,这种方法可以显著减小文件体积。
5.2 模型安全性
直接加载pickle格式的模型文件存在安全风险,因为pickle可以执行任意代码。建议:
- 只从可信来源加载模型
- 考虑使用torch.jit.save保存脚本化模型:
scripted_model = torch.jit.script(model) torch.jit.save(scripted_model, 'secure_model.pt')5.3 跨版本兼容性
PyTorch不同版本间可能存在兼容性问题。我的经验是:
- 尽量保持训练和部署环境一致
- 对于长期保存的模型,同时保存模型定义代码
- 考虑导出为ONNX等通用格式:
dummy_input = torch.randn(1, 3, 224, 224) torch.onnx.export(model, dummy_input, "model.onnx")6. 高级技巧与实用工具
除了基本的保存和加载,还有一些高级技巧可以提升工作效率。这些是我在项目中积累的实用经验。
6.1 模型差异比较
有时需要比较两个模型的参数差异:
def compare_models(model1, model2): for (name1, param1), (name2, param2) in zip(model1.named_parameters(), model2.named_parameters()): if not torch.equal(param1, param2): print(f"参数 {name1} 不同") print(f"差异大小: {torch.norm(param1 - param2)}")这个函数在调试模型加载问题时非常有用。
6.2 参数冻结技巧
加载预训练模型后,经常需要冻结部分层:
for name, param in model.named_parameters(): if 'fc' not in name: # 只训练全连接层 param.requires_grad = False6.3 自定义保存格式
对于特殊需求,可以自定义保存内容:
torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, 'metrics': {'accuracy': acc, 'f1': f1} }, 'custom_checkpoint.pth')这种格式在科研项目中特别有用,可以保存完整的实验状态。
