当前位置: 首页 > news >正文

别再死记硬背优化器公式了!用PyTorch代码实战SGD、Momentum、Adam,看完就会用

深度学习优化器实战指南:用PyTorch代码理解SGD到Adam的核心差异

当你在PyTorch中第一次写下optim.SGD(model.parameters(), lr=0.01)时,是否好奇过这个看似简单的参数更新背后藏着怎样的数学魔法?本文将带你用代码解剖五种主流优化器的工作机制,通过可视化对比它们的实际表现,让你在工程实践中能做出明智选择。

1. 环境准备与基准模型搭建

在开始优化器对比之前,我们需要建立一个统一的实验环境。这里选择经典的MNIST手写数字识别作为测试任务,因为它足够简单又能清晰展示优化器的特性差异。

import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms import matplotlib.pyplot as plt # 数据准备 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) # 基准模型定义 class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.fc = nn.Linear(320, 10) def forward(self, x): x = torch.relu(torch.max_pool2d(self.conv1(x), 2)) x = torch.relu(torch.max_pool2d(self.conv2(x), 2)) x = x.view(-1, 320) return self.fc(x)

这个简单的CNN包含两个卷积层和一个全连接层,足够完成MNIST分类任务。我们将保持模型结构不变,仅更换优化器来观察不同算法的表现。

2. 基础优化器:SGD与Momentum

2.1 原始SGD的实现与局限

随机梯度下降(SGD)是最基础的优化器,其更新规则简单直接:

# SGD优化器初始化 optimizer = optim.SGD(model.parameters(), lr=0.01) # 训练循环中的参数更新步骤 for inputs, labels in train_loader: optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 执行w = w - lr * gradient

SGD的主要问题在于:

  • 学习率固定导致收敛速度不稳定
  • 容易陷入局部极小值或鞍点
  • 在峡谷地形(某些维度梯度大,某些维度梯度小)中震荡明显

2.2 引入动量的改进方案

动量法通过积累历史梯度信息来平滑更新方向:

# 带动量的SGD optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 更新过程变为: # v = momentum * v + gradient # w = w - lr * v

动量参数(通常设为0.9)决定了历史梯度的影响程度。我们通过对比实验展示差异:

优化器类型训练损失(10 epoch)测试准确率收敛稳定性
原始SGD0.3592.1%波动较大
SGD+Momentum0.2194.7%平滑稳定

提示:动量法特别适合处理损失函数表面存在大量局部极小值的情况,能帮助参数更新"冲过"这些障碍。

3. 自适应学习率优化器:Adagrad与RMSProp

3.1 Adagrad的自适应策略

Adagrad通过累计梯度平方来自适应调整每个参数的学习率:

optimizer = optim.Adagrad(model.parameters(), lr=0.01) # 更新规则: # r = r + gradient^2 # w = w - lr / (sqrt(r) + eps) * gradient

Adagrad的特点:

  • 稀疏特征对应的参数会获得更大的更新
  • 随着训练进行,学习率会自动衰减
  • 适合处理稀疏数据

3.2 RMSProp的改进方案

RMSProp通过引入衰减系数解决Adagrad学习率单调下降的问题:

optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99) # 更新规则: # r = alpha * r + (1-alpha) * gradient^2 # w = w - lr / (sqrt(r) + eps) * gradient

关键参数alpha控制历史信息的衰减速度。对比实验数据:

优化器训练速度最终性能学习率衰减情况
Adagrad初期快后期慢93.5%单调下降
RMSProp稳定快速95.2%动态平衡

4. 综合王者:Adam优化器

Adam结合了动量法和RMSProp的优点,成为当前最流行的优化器:

optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999)) # 更新过程: # 1. 计算动量:m = beta1*m + (1-beta1)*gradient # 2. 计算梯度平方:v = beta2*v + (1-beta2)*gradient^2 # 3. 偏差校正:m_hat = m/(1-beta1^t), v_hat = v/(1-beta2^t) # 4. 参数更新:w = w - lr * m_hat/(sqrt(v_hat)+eps)

Adam的核心优势在于:

  • 自动调整每个参数的学习率
  • 内置动量机制加速收敛
  • 偏差校正确保初期稳定性

不同优化器的损失曲线对比:

# 绘制训练曲线 plt.figure(figsize=(10,6)) for opt_name, losses in loss_history.items(): plt.plot(losses, label=opt_name) plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.show()

5. 工程实践中的选择策略

在实际项目中,优化器选择需要考虑以下因素:

数据集特性:

  • 小规模数据:SGD或带动量的SGD
  • 大规模稀疏数据:Adagrad
  • 一般情况:Adam/RMSProp

模型架构:

  • 传统CNN:Adam
  • RNN/LSTM:RMSProp或Adam
  • 预训练模型微调:带动量的SGD

超参数调优技巧:

  • 学习率:Adam通常从3e-4开始尝试
  • 动量参数:0.9是良好起点
  • Adam的beta1/beta2:除非特殊需求,否则保持默认
# 典型Adam参数配置示例 optimizer = optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

注意:虽然Adam被称为"无需调参",但适当调整初始学习率和weight_decay仍能显著提升性能。在训练后期,切换到SGD有时能获得更好的最终精度。

在最近参与的图像分割项目中,我发现当训练数据存在明显类别不平衡时,配合适当学习率衰减策略的AdamW(Adam+权重衰减)表现最为稳定。特别是在训练初期,自适应学习率能有效缓解头部类别主导训练的问题。

http://www.jsqmd.com/news/1010899/

相关文章:

  • 别只当操作手册!深入解读SAP FIORI ICMR对账App的设计逻辑与业务价值
  • Anthropic协议直通架构:消除LLM服务胶水层实现延迟归零
  • 写文章10分钟_发平台1小时_用AI内容多平台适配把时间抢回来
  • 2026池州全城黄金回收口碑商户盘点 TOP铂金回收白银回收旧料回收门店电话地址一览 - 信誉隆金银铂奢回收
  • 2026臻选:上城区四季青疏通下水道 724 小时运维保障 居顺联家政疏通靠谱服务详解 - 居顺联家政疏通
  • 2026东营市民高频光顾的 5 家线下黄金回收白银铂金回收实体店实地走访测评 - 中安检金银铂钻回收
  • 2026河南本地贵金属变现门店精选前五+黄金铂金白银金条回收合规商家名录 含地址电话 - 诚金汇钻回收公司
  • 2026阜阳本地贵金属变现门店精选前五+黄金铂金白银金条回收合规商家名录 含地址电话 - 诚金汇钻回收公司
  • 2026实力之选:机床空调及机柜电柜电箱控制箱无冷水空调制造工厂深度解析 - 品牌发掘
  • CNN中Pooling层的本质:空间鲁棒性构建与实战避坑指南
  • Python Turtle 画生日蛋糕保姆级教程:从数学函数到动画效果的完整实现
  • SAP CK11N成本估算实战:BAPI与BDC两种自动化方案详解与避坑指南
  • 多平台发文最烦调格式_AI自动排版发布帮我搞定了
  • 大连瓦房店市专业房屋漏水检测,精准定位漏水点,快速解决各类渗漏难题-2026年大连房屋漏水检测推荐公司 - 同城资讯
  • 2026大理市民高频光顾的 5 家线下黄金回收白银铂金回收实体店实地走访测评 - 中安检金银铂钻回收
  • 2026海口本地贵金属变现门店精选前五+黄金铂金白银金条回收合规商家名录 含地址电话 - 诚金汇钻回收公司
  • Kimi K2.6 快速 LeetCode 3213. 最小代价构造字符串 Go实现
  • 猫抓Cat-Catch终极指南:3分钟成为浏览器资源嗅探专家
  • 如何快速掌握Typora自动编号功能:面向新手的完整实战指南
  • 内容发了没人看_AI智能发布时机可能是你忽略的那块短板
  • 2026阜阳市民高频光顾的 5 家线下黄金回收白银铂金回收实体店实地走访测评 - 中安检金银铂钻回收
  • 本地部署Llama-3.1替代Claude API的实战指南
  • 2026广西本地贵金属变现门店精选前五+黄金铂金白银金条回收合规商家名录 含地址电话 - 诚金汇钻回收公司
  • 2026成都市民高频光顾的 5 家线下黄金回收白银铂金回收实体店实地走访测评 - 中安检金银铂钻回收
  • 2026保山市民高频光顾的 5 家线下黄金回收白银铂金回收实体店实地走访测评 - 中安检金银铂钻回收
  • C++并发编程选型指南:何时该用无锁队列concurrentqueue,何时用STL queue就够了?
  • 2026鄂尔多斯市民高频光顾的 5 家线下黄金回收白银铂金回收实体店实地走访测评 - 中安检金银铂钻回收
  • 手动发五六个平台太累了_AI全渠道发布是不是解法
  • 避坑必看!2026上海奢侈品黄金回收TOP6实测:机构套路大起底,零套路诚信标杆出炉 - 奢侈品回收
  • 河南郑州GEO服务商如何选择更合适?