当前位置: 首页 > news >正文

Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

相关阅读

Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题,但发现其他参数加载成功,思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型,恰好注册时的默认值给的是0,而checkpoint中的浮点数又在0到1之间,因此出现了这个令人困惑的bug。

下面首先复现这个bug。

import torch import torch.nn as nn # 定义一个简单的线性模型,参数类型为整数 class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.register_buffer('test', torch.tensor(0)) # 注册一个整型张量 # 创建一个简单模型实例 model = SimpleModel() # 创建一个浮点数作为参数 float_parameter = torch.tensor(0.6) # 将注册名指向另一个浮点型张量 model.test = float_parameter # 保存模型 torch.save(model.state_dict(), 'model.pth') # 直接使用原模型加载 checkpoint = torch.load('model.pth') model.load_state_dict(checkpoint) # 打印加载后的参数 print(model.test) # 直接使用新模型加载 model_1 = SimpleModel() model_1.load_state_dict(checkpoint) # 打印加载后的参数 print(model_1.test)
输出: tensor(0.6000) tensor(0)

可以看到,当模型中注册的名字(test),指向了一个类型不符的张量后,并不会导致浮点型张量被截断为整型,这是因为此处是直接使用赋值号=,使名字指向了另一个张量。

但使用load_state_dict()方法与使用赋值号是不同的,load_state_dict()方法的实现中,调用了_load_from_state_dict()方法,其中调用了copy_()方法,进行了原位(in-place)数据替换,这可能会进行截断,下面是原位替换的一个例子。

import torch # 创建两个张量 a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[5.1, 6.1], [7.1, 8.1]]) # 查看张量对象的id print(id(a)) print(id(b)) # 查看底层存储的内存地址 print(a.storage().data_ptr()) print(b.storage().data_ptr()) # 将张量 b 中的值复制到张量 a 中 a.copy_(b) # 打印复制后的结果 print(a) # 查看张量对象的id print(id(a)) print(id(b)) # 查看底层存储的内存地址 print(a.storage().data_ptr()) print(b.storage().data_ptr())
输出: 2604425272672 2604426953808 2604511348096 2602930352832 tensor([[5, 6], [7, 8]]) 2604425272672 2604426953808 2604511348096 2602930352832

在保存了模型的状态字典后,使用load_state_dict()方法加载后,也不会有任何截断问题,因为对于原模型而言,名字test指向的是一个浮点型张量,此时原位替换,类型吻合。但是对于一个新的模型,此时的test指向的是一个整型张量,此时原位替换,会发生截断。

因此,在注册一个张量时,需要确保其在注册时和保存时的类型吻合,此处除了指形状,还有类型,否则可能会出现意想不到的bug。

http://www.jsqmd.com/news/862385/

相关文章:

  • 工业眼睛:11 老手血泪Tips + 新手避坑清单
  • 2026年靠谱的浙江时效物流快运/龙港物流快运售后无忧公司 - 行业平台推荐
  • Agent Runtime 正在 commoditize:从 session-as-event-log 看 AI 基础设施分层
  • ishell 错误处理与中断机制:构建健壮的交互式应用
  • 数据结构知识点
  • 2026年北京市外资研发中心(第九批)认定通知
  • 2026年口碑好的合肥GEO排名优化/安徽GEO排名优化推荐榜单公司 - 行业平台推荐
  • AI能力评估中的事实核查与术语规范
  • Vue3 入门到进阶:vite 搭建、响应式原理与新组件实战
  • CANN/asc-devkit int8转half API文档
  • 2026年05月智慧泵房优选:口碑与实力并存的公司,供水控制柜/光伏太阳能供水设备/长轴消防泵,智慧泵房制造厂家推荐 - 品牌推荐师
  • 智慧树刷课插件:3个功能让你告别手动操作,节省50%学习时间
  • 保姆级教程:用Conda为Stable Diffusion WebUI创建纯净Python环境,彻底告别启动崩溃
  • DeepCreamPy图像修复终极指南:AI智能去码快速上手教程
  • 告别Transformer卡顿!用SegMamba在3D医学图像分割上实现又快又准(附BraTS2023实战代码)
  • Airflow Maintenance Dags项目架构深度剖析:从代码实现到生产部署
  • 2026年比较好的5G数据采集网关/深圳边缘计算数据采集网关/定位和锁机远程运维网关/深圳5G数据采集网关用户好评公司 - 品牌宣传支持者
  • NotaGen终极指南:基于大语言模型的高质量古典乐谱生成解决方案
  • 从手机摄像头到天文望远镜:一文搞懂CCD传感器是如何‘看见’世界的
  • windows8080端口被占用 ?
  • AD7616前端设计避坑指南:RCR滤波器如何影响谐波测量精度?从硬件到软件的补偿思路
  • 数字电路-74LS148的5路呼叫显示和74LS373的8路抢答器
  • CANN/pypto张量创建指南
  • Musicn安全使用指南:避免版权风险的最佳实践
  • 2026年推荐哈尔滨铜门公司选择指南 - 品牌宣传支持者
  • Windows 7 SP2终极解决方案:三步告别硬件兼容性问题,让经典系统焕发新生
  • Gemini赋能安全工程师:自动生成PoC脚本的技术实践
  • GitHub Desktop中文汉化终极指南:5分钟让英文界面变中文
  • Sixpack Redis数据存储策略:高效管理A/B测试数据的10个技巧
  • Mainframer错误排查指南:常见问题及解决方法大全