当前位置: 首页 > news >正文

别再死记硬背Adam公式了!用Python手搓一个Adam优化器,彻底搞懂偏差修正和矩估计

别再死记硬背Adam公式了!用Python手搓一个Adam优化器,彻底搞懂偏差修正和矩估计

在机器学习的优化算法领域,Adam(Adaptive Moment Estimation)无疑是最耀眼的明星之一。但很多人在使用Adam时,只是机械地调用torch.optim.Adam()keras.optimizers.Adam(),对其内部工作原理一知半解。今天,我们将打破这种"黑箱"使用模式,从零开始用Python实现一个完整的Adam优化器,通过代码深入理解其核心机制。

1. 优化器基础:从SGD到Adam的进化之路

在进入Adam的具体实现之前,我们需要了解优化器的发展脉络。传统的随机梯度下降(SGD)虽然简单直接,但在实际应用中存在几个明显缺陷:

  • 学习率需要手动调整,且对所有参数使用相同的学习率
  • 在损失函数的某些方向(如峡谷地形)容易振荡
  • 对于稀疏梯度处理效果不佳

为了解决这些问题,研究者们提出了一系列改进算法:

  1. Momentum:引入"动量"概念,考虑历史梯度信息
  2. AdaGrad:为不同参数自适应调整学习率
  3. RMSprop:改进AdaGrad的梯度累积方式

Adam则综合了这些算法的优点,成为当前最流行的优化器之一。它的核心创新在于:

  • 同时计算梯度的一阶矩估计(均值)和二阶矩估计(未中心化的方差)
  • 对这两个估计进行偏差修正
  • 自适应地为每个参数计算不同的学习率

2. Adam的核心算法解析

Adam的核心算法可以分为以下几个步骤:

  1. 计算当前梯度的一阶矩估计(均值)和二阶矩估计(方差)
  2. 对这些估计进行偏差修正
  3. 根据修正后的估计更新参数

用数学公式表示如下:

m_t = β1 * m_{t-1} + (1 - β1) * g_t # 一阶矩估计 v_t = β2 * v_{t-1} + (1 - β2) * g_t^2 # 二阶矩估计 m̂_t = m_t / (1 - β1^t) # 偏差修正后的一阶矩 v̂_t = v_t / (1 - β2^t) # 偏差修正后的二阶矩 θ_t = θ_{t-1} - α * m̂_t / (√v̂_t + ε) # 参数更新

其中:

  • g_t是当前梯度
  • β1β2是指数衰减率(通常取0.9和0.999)
  • α是学习率
  • ε是一个极小值(通常1e-8)用于数值稳定性

3. 从零实现Adam优化器

现在,让我们用Python和NumPy来实现一个完整的Adam优化器。我们将创建一个类,模仿PyTorch优化器的接口风格。

import numpy as np class AdamOptimizer: def __init__(self, params, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8): """ 初始化Adam优化器 参数: params: 待优化的参数列表 lr: 学习率 (默认: 0.001) beta1: 一阶矩估计的指数衰减率 (默认: 0.9) beta2: 二阶矩估计的指数衰减率 (默认: 0.999) eps: 数值稳定项 (默认: 1e-8) """ self.params = params self.lr = lr self.beta1 = beta1 self.beta2 = beta2 self.eps = eps self.t = 0 # 时间步 # 初始化一阶和二阶矩估计 self.m = [np.zeros_like(p) for p in params] self.v = [np.zeros_like(p) for p in params] def step(self, grads): """ 执行一步参数更新 参数: grads: 对应参数的梯度列表 """ self.t += 1 for i, (param, grad) in enumerate(zip(self.params, grads)): # 更新一阶矩估计 self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad # 更新二阶矩估计 self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (grad ** 2) # 计算偏差修正后的估计 m_hat = self.m[i] / (1 - self.beta1 ** self.t) v_hat = self.v[i] / (1 - self.beta2 ** self.t) # 更新参数 param -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

这个实现包含了Adam的所有关键组件:

  • 一阶矩估计(m)和二阶矩估计(v)的维护
  • 偏差修正的计算
  • 参数更新逻辑

4. 偏差修正的直观理解

Adam中的偏差修正是初学者最容易困惑的部分。为什么需要修正?让我们通过一个简单的例子来理解。

假设我们有一个常数梯度g_t = 1β1 = 0.9,初始m_0 = 0。前几个时间步的矩估计为:

m_1 = 0.9 * 0 + 0.1 * 1 = 0.1 m_2 = 0.9 * 0.1 + 0.1 * 1 = 0.19 m_3 = 0.9 * 0.19 + 0.1 * 1 = 0.271 ...

可以看到,初始阶段的矩估计明显低于真实梯度值(1)。偏差修正通过除以(1 - β^t)来补偿这种低估:

m̂_1 = 0.1 / (1 - 0.9^1) = 0.1 / 0.1 = 1 m̂_2 = 0.19 / (1 - 0.9^2) = 0.19 / 0.19 ≈ 1 m̂_3 = 0.271 / (1 - 0.9^3) = 0.271 / 0.271 ≈ 1

随着t增大,β^t趋近于0,修正因子趋近于1,修正的影响逐渐消失。

5. 在MNIST数据集上验证我们的实现

为了验证我们的Adam实现是否正确,我们将在MNIST数据集上训练一个简单的全连接网络,并与PyTorch内置的Adam优化器进行对比。

import torch import torch.nn as nn import torchvision from torchvision import transforms # 准备MNIST数据集 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) # 定义一个简单的全连接网络 class SimpleNet(nn.Module): def __init__(self): super(SimpleNet, self).__init__() self.fc1 = nn.Linear(28*28, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = x.view(-1, 28*28) x = torch.relu(self.fc1(x)) x = self.fc2(x) return x # 初始化两个相同的网络 net1 = SimpleNet() net2 = SimpleNet() net2.load_state_dict(net1.state_dict()) # 确保初始参数相同 # 使用PyTorch的Adam优化器训练net1 optimizer1 = torch.optim.Adam(net1.parameters(), lr=0.001) # 使用我们的Adam实现训练net2 params = [p.data.numpy() for p in net2.parameters()] optimizer2 = AdamOptimizer(params, lr=0.001) criterion = nn.CrossEntropyLoss() for epoch in range(5): for i, (inputs, labels) in enumerate(trainloader): # 训练net1 (PyTorch Adam) optimizer1.zero_grad() outputs1 = net1(inputs) loss1 = criterion(outputs1, labels) loss1.backward() optimizer1.step() # 训练net2 (我们的Adam) outputs2 = net2(inputs) loss2 = criterion(outputs2, labels) net2.zero_grad() loss2.backward() # 获取梯度并更新参数 grads = [p.grad.data.numpy() for p in net2.parameters()] optimizer2.step(grads) # 比较两个网络的参数差异 if i % 100 == 0: diff = sum(np.sum(np.abs(p1.data.numpy() - p2)) for p1, p2 in zip(net1.parameters(), params)) print(f'Epoch {epoch}, Batch {i}, Parameter difference: {diff:.6f}')

如果我们的实现正确,两个网络的参数差异应该保持在一个很小的范围内(由于浮点计算顺序等微小差异)。

6. Adam的变体与实践技巧

虽然标准Adam已经表现很好,但在实践中还有一些变体和技巧值得关注:

  1. AMSGrad:解决Adam可能收敛到次优解的问题
  2. AdamW:将权重衰减与梯度更新解耦
  3. NAdam:结合Nesterov动量的Adam变体

在实际使用Adam时,有几个经验性的建议:

  • 默认参数(lr=0.001,beta1=0.9,beta2=0.999)在大多数情况下表现良好
  • 对于特别深或复杂的网络,可能需要调低学习率
  • 在训练初期,可以观察损失曲线判断是否需要调整beta1beta2
  • 配合学习率调度器(如ReduceLROnPlateau)使用效果更佳

7. 可视化Adam的内部状态

为了更直观地理解Adam的工作原理,我们可以可视化训练过程中一些关键量的变化:

import matplotlib.pyplot as plt # 假设我们有一个简单的二次函数 f(x) = x^2 x = np.linspace(-2, 2, 100) y = x**2 # 初始化参数和优化器 param = np.array([1.8]) # 初始位置 optimizer = AdamOptimizer([param], lr=0.1) # 存储轨迹和内部状态 trajectory = [param.copy()] ms = [] vs = [] for t in range(1, 101): grad = 2 * param # f(x)=x^2的梯度是2x optimizer.step([grad]) trajectory.append(param.copy()) ms.append(optimizer.m[0].copy()) vs.append(optimizer.v[0].copy()) # 绘制优化轨迹和函数曲线 plt.figure(figsize=(12, 5)) plt.subplot(1, 2, 1) plt.plot(x, y, label='f(x) = x²') plt.plot(trajectory, [p**2 for p in trajectory], 'ro-', label='Optimization path') plt.xlabel('x') plt.ylabel('f(x)') plt.legend() # 绘制m和v的变化 plt.subplot(1, 2, 2) plt.plot(ms, label='First moment (m)') plt.plot(vs, label='Second moment (v)') plt.xlabel('Iteration') plt.ylabel('Value') plt.legend() plt.show()

这样的可视化可以帮助我们理解:

  • Adam如何结合历史梯度信息(m)和梯度幅度信息(v)
  • 偏差修正在初期如何影响更新
  • 自适应学习率如何根据梯度特性调整

8. 常见问题与调试技巧

在实际使用Adam时,可能会遇到一些典型问题:

问题1:训练初期损失下降缓慢

  • 可能原因:偏差修正导致初期更新步长过小
  • 解决方案:适当提高初始学习率或减小beta1

问题2:训练后期出现振荡

  • 可能原因:二阶矩估计v变得太小
  • 解决方案:检查eps值是否合适,或尝试AMSGrad变体

问题3:模型收敛到次优解

  • 可能原因:自适应学习率导致某些方向更新不足
  • 解决方案:尝试结合SGD或使用学习率预热策略

调试Adam时,可以监控以下指标:

  • 梯度的一阶矩和二阶矩的统计量
  • 参数更新的相对幅度
  • 不同层的学习率比例

在实现自己的优化器时,确保数值稳定性至关重要。特别是:

  • 处理sqrt(v_hat)时添加足够小的eps
  • 注意浮点数精度问题
  • 在偏差修正中避免除以零
http://www.jsqmd.com/news/743572/

相关文章:

  • 多模态提示词实战指南:解锁GPT-4V与DALL·E 3高效应用
  • SD-PPP:如何通过插件架构革命实现创意工作流的无缝融合
  • 如何用深度学习实现95%准确率的实时手语翻译系统?
  • 基于计算机视觉与自动化控制技术的游戏辅助系统:MaaAssistantArknights深度解析
  • 【技术解密】Jasminum:破解中文文献管理难题的智能元数据引擎
  • Warcraft Helper:深度解析魔兽争霸III现代兼容性解决方案
  • CefFlashBrowser终极指南:在Windows上完美运行Flash游戏和内容的完整教程
  • 手机号码定位工具终极指南:3步快速查询归属地
  • 字幕自动化管理:ajnart/subs工具实战与媒体库集成指南
  • 告别Root!在Termux里用Ubuntu创建普通用户的保姆级避坑指南
  • 魔兽争霸III兼容性问题终极解决方案:Warcraft Helper插件全攻略
  • 如何高效制作Fedora系统启动盘:跨平台工具完整指南
  • KeymouseGo:三分钟学会鼠标键盘自动化,让你的工作效率提升300%
  • ShareX:集屏幕截图、文件共享与生产力工具于一体,多渠道获取信息!
  • RAG技术如何优化LLM在垂直领域的知识检索
  • 4D内容生成与重建:解耦LoRA控制技术解析
  • 阿里云2026年5月Hermes Agent/OpenClaw如何部署?百炼token Plan配置
  • Godot引擎WebAssembly部署实战:优化构建与网页游戏开发指南
  • 基于MCP协议的AI驱动部署编排:用自然语言自动化开发工作流
  • PEARL模型:个性化视频理解的动态注意力机制解析
  • Claude桌面应用深度配置指南:打造个性化AI开发工作流
  • 构建一个基于 TD3 (Twin Delayed DDPG) 算法的永磁同步电机(PMSM)电流环控制系统
  • 如何永久禁用Windows Defender?开源工具Defender Control的3步解决方案
  • 3步解决C盘爆红难题:开源神器WindowsCleaner完全使用指南
  • 原神成就数据自动化导出工具:YaeAchievement技术架构与实现原理深度解析
  • AI智能体任务规范:从概念到实践,构建可靠的多步骤自动化工作流
  • AI编程助手实战:通过Cursor练习项目掌握高效开发技巧
  • 阿里云2026年5月Hermes Agent/OpenClaw搭建解析,百炼token Plan配置指南
  • ARM Fast Models Trace组件:原理、功能与调试实践
  • ipasim技术解密:Windows平台iOS应用模拟器的架构剖析与实战指南