当前位置: 首页 > news >正文

加载权重文件后发现准确率有问题

保存权重文件时,最好使用copy.deepcopy,不然可能出现引用的问题,导致本应该保存best pth的变成保存最后一个epoch的pth。

/root/unified_nas/training/trainer.py

# 更新最佳模型 if val_metrics['accuracy'] > best_accuracy: best_accuracy = val_metrics['accuracy'] best_val_metrics = val_metrics best_model_state = { # 'model': self.model.state_dict(), # 'head': self.task_head.state_dict() 'model': copy.deepcopy(self.model.state_dict()), # ✅ 深拷贝 'head': copy.deepcopy(self.task_head.state_dict()) # ✅ 深拷贝 } # 保存最佳模型权重到文件 torch.save(best_model_state, save_path) # print(f"✅ Best model saved with accuracy: {best_accuracy:.2f}%") self._output(f"✅ Best model saved with accuracy: {best_accuracy:.2f}%")

这部分的完整代码如下:

import torch import torch.nn as nn from torch.optim import Adam from tqdm import tqdm import numpy as np from collections import defaultdict import copy # 设置随机数种子 SEED = 42 # 你可以选择任何整数作为种子 torch.manual_seed(SEED) torch.cuda.manual_seed(SEED) class SingleTaskTrainer: """ 针对单个数据集的训练器 """ def __init__(self, model, dataloaders, device='cuda', logger=None): """ 初始化训练器 参数: model: 要训练的模型 dataloaders: 数据加载器字典,包含 'train' 和 'test' 两个键 device: 训练设备 ('cuda' 或 'cpu') """ self.model = model.to(device) self.dataloaders = dataloaders self.device = device self.logger = logger # 如果没有提供logger,创建一个简单的logger来模拟print行为 # 确保模型有 output_dim 属性 if not hasattr(model, 'output_dim'): raise AttributeError("Model must have 'output_dim' attribute") # 获取类别数 self.num_classes = len(dataloaders['train'].dataset.classes) print(f"Number of classes: {self.num_classes}") # 创建任务头 self.task_head = nn.Linear(model.output_dim, self.num_classes).to(device) # 定义损失函数和优化器 self.criterion = nn.CrossEntropyLoss() self.optimizer = Adam( list(model.parameters()) + list(self.task_head.parameters()), lr=1e-3 ) def _output(self, message): """统一的输出方法:如果有logger则使用logger,否则使用print""" if self.logger: self.logger.info(message) else: print(message) def train_epoch(self): """ 单个训练周期 """ self.model.train() self.task_head.train() running_loss = 0.0 correct = 0 total = 0 for inputs, labels in tqdm(self.dataloaders['train'], desc="Training"): inputs = inputs.to(self.device) labels = labels.to(self.device) self.optimizer.zero_grad() features = self.model(inputs) outputs = self.task_head(features) loss = self.criterion(outputs, labels) loss.backward() self.optimizer.step() running_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() metrics = { 'loss': running_loss / len(self.dataloaders['train']), 'accuracy': 100. * correct / total } return metrics def evaluate(self): """ 模型评估 """ self.model.eval() self.task_head.eval() running_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for inputs, labels in tqdm(self.dataloaders['test'], desc="Evaluating"): inputs = inputs.to(self.device) labels = labels.to(self.device) features = self.model(inputs) outputs = self.task_head(features) loss = self.criterion(outputs, labels) running_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() metrics = { 'loss': running_loss / len(self.dataloaders['test']), 'accuracy': 100. * correct / total } return metrics def train(self, epochs=10, save_path='best_model.pth'): """ 训练模型并保存最佳权重 参数: epochs: 训练周期数 save_path: 最佳模型权重保存路径 返回: best_accuracy: 最佳验证准确率 best_val_metrics: 最佳验证指标 history: 训练历史记录 best_model_state: 最佳模型状态字典 """ best_accuracy = 0.0 best_val_metrics = None # 保存最佳验证指标 history = [] best_model_state = None # 保存最佳模型状态 for epoch in range(epochs): # print(f"\nEpoch {epoch + 1}/{epochs}") self._output(f"\nEpoch {epoch + 1}/{epochs}") # 训练阶段 train_metrics = self.train_epoch() # 验证阶段 val_metrics = self.evaluate() # 保存历史 history.append({ 'train': train_metrics, 'val': val_metrics }) # print(f"\nValidation Accuracy: {val_metrics['accuracy']:.2f}%") self._output(f"\nValidation Accuracy: {val_metrics['accuracy']:.2f}%") # 更新最佳模型 if val_metrics['accuracy'] > best_accuracy: best_accuracy = val_metrics['accuracy'] best_val_metrics = val_metrics best_model_state = { # 'model': self.model.state_dict(), # 'head': self.task_head.state_dict() 'model': copy.deepcopy(self.model.state_dict()), # ✅ 深拷贝 'head': copy.deepcopy(self.task_head.state_dict()) # ✅ 深拷贝 } # 保存最佳模型权重到文件 torch.save(best_model_state, save_path) # print(f"✅ Best model saved with accuracy: {best_accuracy:.2f}%") self._output(f"✅ Best model saved with accuracy: {best_accuracy:.2f}%") return best_accuracy, best_val_metrics, history, best_model_state

Note:不要轻易改模型结构,与其改模型结构来调整问题,不如相信就按照现在的结构继续做。问题往往意想不到。

http://www.jsqmd.com/news/294356/

相关文章:

  • 2026英语雅思培训学校机构辅导机构排行榜 家长择校完全指南:多维度评测帮孩子选对适配辅导机构
  • 2026英语雅思学习辅导机构推荐榜单 家长择校完全指南:多维度评测解析帮孩子选对适配机构
  • 2026英语雅思学习辅导机构排行榜+核心解析 家长择校实用指南 帮孩子精准匹配雅思学习全阶段适配方案避误区
  • 并查集及其应用专题--全网最详细版
  • 聚焦5家瑞祥卡回收1分钟高效操作平台
  • 2025年目前靠谱的花灯企业推荐榜单,春节国潮花灯/十二生肖花灯/宫灯/互动花灯/营销花灯/古镇花灯,花灯实力厂家哪家好
  • 2026英语雅思培训学校机构辅导机构排行榜+核心解析 家长择校实用指南 精准匹配孩子备考需求
  • 2026英语雅思学习辅导机构排行榜 家长择校实用指南:多维度评测帮孩子选对适配学习机构
  • 2026英语雅思学习辅导机构排行榜+核心解析 家长择校实用指南 帮孩子精准匹配适配的雅思备考方案
  • 四川高中复读学校推荐:家长关注的几所学校,高中/实验中学/学校/中学/高中复读学校,高中复读学校生产厂家联系方式
  • AI赋能创始人表达:从个人智慧到组织能力的战略跃迁
  • 2026年合肥中职择校指南:五大口碑校深度解析与趋势前瞻
  • 创始人IP:新质生产力时代,企业的“人格化”护城河
  • 2026英语雅思培训班辅导机构推荐榜单 家长择校完全指南:多维度评测解析帮孩子选对适配机构
  • 2026英语雅思培训班辅导机构排行榜+核心解析 家长择校实用指南 帮孩子精准匹配雅思备考全阶段适配方案
  • 合肥对口高考院校深度测评与选择指南(2026届考生必读)
  • 高校毕业生实习及就业去向信息管理系统(编号:3394424) --论文vue3
  • 2026英语雅思学习辅导机构排行榜+核心解析 家长择校完全指南 精准匹配孩子备考需求避误区
  • 基于 SpringCloud 的作品投票系统vue3
  • 全国通用的京东e卡回收多种任选渠道
  • 基于python的家教预约服务平台vue3
  • 2026英语雅思培训班辅导机构排行榜+核心解析 家长择校完全指南 帮孩子精准匹配适配的雅思备考方案避误区
  • 基于Spring Cloud技术的智慧云停车场服务管理系统vue3
  • 40种绕过WAF防火墙的Payload混淆技术,从零基础到精通,收藏这篇就够了!_waf绕过技战术
  • 基于spring mvc和mybatis的网上食品零食商城系统视频vue3
  • 基于的城市公交查询地图系统(编号:1410396)--论文vue3
  • Java毕设选题推荐:基于springboot的智慧生产安全系统安全巡检系统的设计与实现【附源码、mysql、文档、调试+代码讲解+全bao等】
  • 【时间之外】AI招聘这么干行不行?
  • 计算机毕业设计Python深度学习物流网络优化与货运路线规划系统 智慧交通 机器学习 大数据毕设(源码 +LW文档+PPT+讲解)
  • Java毕设项目:基于springboot的游戏售卖商城系统(源码+文档,讲解、调试运行,定制等)