当前位置: 首页 > news >正文

112_深度学习的导航仪:PyTorch 优化器(Optimizer)全解析

在经历了前向传播计算 Loss、反向传播计算梯度(Gradient)后,我们来到了最关键的一步:更新参数。优化器就像是一位经验丰富的导航员,它根据梯度指示的方向,决定如何调整模型的权重,使 Loss 降到最低。

1. 优化器的核心逻辑

优化器的主要工作包含以下三个步骤,缺一不可:

  1. 梯度清零 (zero_grad):在每一轮计算开始前,必须把之前的梯度清空,否则梯度会不断累加,导致训练出错。
  2. 反向传播 (backward):计算当前误差对每个参数的梯度。
  3. 参数更新 (step):根据选定的算法(如 SGD、Adam)和梯度值,实际修改网络中的权重。

2. 实战代码:神经网络的完整训练循环

通过 CIFAR-10 数据集演示了如何使用SGD(随机梯度下降)优化器进行多轮(Epoch)训练。

import torch import torchvision from torch import nn from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True) dataloader = DataLoader(dataset, batch_size=64,drop_last=True) class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.model1 = Sequential( Conv2d(3,32,5,padding=2), MaxPool2d(2), Conv2d(32,32,5,padding=2), MaxPool2d(2), Conv2d(32,64,5,padding=2), MaxPool2d(2), Flatten(), Linear(1024,64), Linear(64,10) ) def forward(self, x): x = self.model1(x) return x loss = nn.CrossEntropyLoss() # 交叉熵 tudui = Tudui() optim = torch.optim.SGD(tudui.parameters(),lr=0.01) # 随机梯度下降优化器 for epoch in range(20): running_loss = 0.0 for data in dataloader: imgs, targets = data outputs = tudui(imgs) result_loss = loss(outputs, targets) # 计算实际输出与目标输出的差距 optim.zero_grad() # 梯度清零 result_loss.backward() # 反向传播,计算损失函数的梯度 optim.step() # 根据梯度,对网络的参数进行调优 running_loss = running_loss + result_loss print(running_loss) # 对这一轮所有误差的总和

3. 进阶:学习率调整策略 (LR Scheduler)

文件中还提到了一个进阶工具:StepLR

  • 为什么要调整学习率?训练初期我们希望走得快(学习率大),训练后期为了精准落入最低点,我们需要走得慢(学习率小)。
  • 代码实现
import torch import torchvision from torch import nn from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter dataset = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True) dataloader = DataLoader(dataset, batch_size=64,drop_last=True) class Tudui(nn.Module): def __init__(self): super(Tudui, self).__init__() self.model1 = Sequential( Conv2d(3,32,5,padding=2), MaxPool2d(2), Conv2d(32,32,5,padding=2), MaxPool2d(2), Conv2d(32,64,5,padding=2), MaxPool2d(2), Flatten(), Linear(1024,64), Linear(64,10) ) def forward(self, x): x = self.model1(x) return x loss = nn.CrossEntropyLoss() # 交叉熵 tudui = Tudui() optim = torch.optim.SGD(tudui.parameters(),lr=0.01) # 随机梯度下降优化器 scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=5, gamma=0.1) # 每过 step_size 更新一次优化器,更新是学习率为原来的学习率的的 0.1 倍 for epoch in range(20): running_loss = 0.0 for data in dataloader: imgs, targets = data outputs = tudui(imgs) result_loss = loss(outputs, targets) # 计算实际输出与目标输出的差距 optim.zero_grad() # 梯度清零 result_loss.backward() # 反向传播,计算损失函数的梯度 optim.step() # 根据梯度,对网络的参数进行调优 scheduler.step() # 学习率太小了,所以20个轮次后,相当于没走多少 running_loss = running_loss + result_loss print(running_loss) # 对这一轮所有误差的总和

4. 总结:训练全流程闭环

分析完此文件后,我们终于完成了 PyTorch 训练的完整拼图:

  1. 准备数据(Dataset & DataLoader)
  2. 搭建结构(nn.Module & Sequential)
  3. 衡量误差(Loss Function)
  4. 计算方向(Backward)
  5. 调整参数(Optimizer)

💡 学习心得

优化器的lr(学习率)设置非常关键。设置过大,模型可能在最低点附近反复横跳无法收敛;设置过小,模型学习速度会极慢。在实际开发中,Adam优化器由于其自适应学习率的特性,通常比SGD更容易上手。

http://www.jsqmd.com/news/517961/

相关文章:

  • 香橙派 AIpro 实战:从零部署 YOLOv8 模型避坑指南(附昇腾 ATC 转换技巧)
  • UE5 蓝图入门 - 从零开始构建你的第一个交互功能
  • 不用写代码!手把手教你用ChatGPT+开源工具自动生成专业PPT(附避坑指南)
  • JVM面试杂知识
  • 探索虚拟同步发电机的MATLAB仿真之旅
  • Qwen与MinerU文档处理对比:哪个更适合中小企业自动化办公场景?
  • 通义千问2.5-7B保姆级教程:零基础5分钟本地部署,小白也能玩转AI对话
  • 【技术揭秘】快速识别网站服务器类型:Nginx与Apache的实战技巧
  • 【HALCON工业视觉应用探索】15. 项目全生命周期管理:从需求到交付的全流程详解
  • AI原生应用与决策支持的融合发展路径探讨
  • Visio中高效插入与编辑矩阵公式的完整指南
  • 【架构心法】删掉多线程!撕开通信死锁的黑盒,用 C++ 单线程状态机重塑极速 ACK 与重传引擎
  • 深度学习必备技能:5分钟用Python画出ReLU家族函数图像(含PReLU参数调整技巧)
  • ICML 2025 | 贝叶斯熵 + 多模态提示,USAM 重新定义 SAM 不确定性量化框架
  • Vue项目登录页刷新报错?手把手教你解决‘undefined is not valid JSON‘问题
  • 用Python和NumPy手把手实现多智能体仿射队形控制(附完整代码与避坑指南)
  • 嵌入式开发实战:MIPI-DSI与I2C接口在LCD触控屏中的协同工作原理
  • 别再死记硬背Attention了!用Python手写一个Seq2Seq翻译模型,直观理解Encoder-Decoder的瓶颈
  • 内存池监控不是加个malloc钩子就够了!揭秘某智能电网项目因监控粒度粗0.1ms导致的3次I级事故
  • 基于RexUniNLU的智能内容审核系统开发
  • AutoJs悬浮窗实战:从零打造可拖拽控制面板(附完整源码解析)
  • 告别CNN黑箱?用Vision Transformer做医学影像分割的实战避坑指南
  • 低成本改造阳台小菜园:用Arduino+继电器模块实现定时滴灌系统
  • Transformer模型中的自注意力机制:从零开始手把手实现(附Python代码)
  • FLAC3D耦合PFC3D隧道开挖模拟:位移连续性与地表沉降规律
  • 大班匠搬家公司联系方式:关于选择专业搬家服务提供商的使用指南与行业普遍注意事项 - 品牌推荐
  • 15 三数之和
  • 北京名人手抄本、老医书、族谱上门回收,线装古籍全品类收 - 品牌排行榜单
  • 【Dify高阶实战指南】:3个生产级异步节点自定义陷阱,90%团队部署后才后悔没看
  • FLAC3D与PFC3D耦合边坡模型,位移连续性优异