当前位置: 首页 > news >正文

大模型教我成为大模型算法工程师之day8: 优化器与训练技巧

Day 8: 优化器与训练技巧

摘要:设计好了神经网络架构只是第一步,如何让它“学”好则是另一个关键挑战。本文深入探讨深度学习中的优化器演进(从SGD到AdamW)、学习率调度策略、关键的归一化技术(BN、LN、RMSNorm)以及防止过拟合的正则化手段。


1. 优化器 (Optimizers)

优化器的作用是根据计算出的梯度来更新模型的权重,以最小化损失函数。

1.1 SGD 与 Momentum

SGD (Stochastic Gradient Descent)是最基础的优化算法,每次只随机抽取一部分样本(Batch)计算梯度并更新。
θ t + 1 = θ t − η ⋅ ∇ J ( θ t ) \theta_{t+1} = \theta_t - \eta \cdot \nabla J(\theta_t)θt+1=θtηJ(θt)
其中η \etaη是学习率。

通俗解释:梯度震荡
想象你在滑雪下山,地形是一个狭长的峡谷,左右坡度很陡,但沿着峡谷向下的坡度很缓。

  • SGD的困境:你站在峡谷一侧,SGD只看脚下,觉得“左边好陡”,于是用力向右冲;冲到对面后又觉得“右边好陡”,于是用力向左冲。结果就是你在峡谷两壁之间来回“剧烈横跳”(震荡),大部分力气花在左右移动上,沿着峡谷向下前进的速度反而很慢。
  • Momentum的作用:引入“动量”模拟惯性。当你左右横跳时,惯性会抵消一部分横向的力,保留更多纵向(沿峡谷走向)的速度,让你能平滑地滑向谷底。

Momentum (动量):为了解决SGD在由于梯度方向震荡导致收敛慢的问题,引入了“动量”概念,模拟物理中的惯性。
v t + 1 = γ v t + η ∇ J ( θ t ) v_{t+1} = \gamma v_t + \eta \nabla J(\theta_t)vt+1=γvt+ηJ(θt)
θ t + 1 = θ t − v t + 1 \theta_{t+1} = \theta_t - v_{t+1}θt+1=θtvt+1
动量项γ \gammaγ通常设为 0.9。

1.2 Adam (Adaptive Moment Estimation)

Adam 结合了 Momentum (一阶动量) 和 RMSProp (二阶动量/自适应学习率) 的优点。它为每个参数计算独立的自适应学习率。

  • 一阶动量(均值):m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1-\beta_1) g_tmt=β1mt1+(1β1)gt
  • 二阶动量(方差):v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2vt=β2vt1+(1β2)gt2
  • 偏差修正:m ^ t = m t / ( 1 − β 1 t ) \hat{m}_t = m_t / (1-\beta_1^t)m^t=mt/(1β1t),v ^ t = v t / ( 1 − β 2 t ) \hat{v}_t = v_t / (1-\beta_2^t)v^t=vt/(1β2t)
  • 更新:θ t + 1 = θ t − η m ^ t v ^ t + ϵ \theta_{t+1} = \theta_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}θt+1=θtηv^t+ϵm^t

1.3 AdamW

AdamW是目前大模型训练的主流选择。它解决了 Adam 中权重衰减(L2正则)实现不正确的问题。在 AdamW 中,权重衰减(Weight Decay)直接应用于权重更新步骤,而不是混入梯度计算中,从而实现了与学习率的解耦。

💻 代码实践:PyTorch优化器选择

importtorchimporttorch.nnasnnimporttorch.optimasoptim model=nn.Linear(10,2)# 1. SGD with Momentumoptimizer_sgd=optim.SGD(model.parameters(),lr=0.01,momentum=0.9)# 2. Adamoptimizer_adam=optim.Adam(model.parameters(),lr=0.001)# 3. AdamW (推荐用于Transformer/大模型)optimizer_adamw=optim.AdamW(model.parameters(),lr=0.001,weight_decay=0.01)

2. 学习率调度 (Learning Rate Schedulers)

恒定的学习率往往无法达到最优解。通常策略是:前期Warmup(预热)以稳定梯度,后期Decay(衰减)以精细收敛

2.1 Warmup

在训练初期,由于梯度变化剧烈,使用较大的学习率容易导致模型不稳定。Warmup 策略是在最初的几步(如前5% steps)将学习率从 0 线性增加到预设的最大值。

2.2 Cosine Annealing (余弦退火)

学习率随训练步数按余弦函数曲线下降。相比于阶梯式下降(Step Decay),余弦退火更加平滑,且往往能获得更好的泛化能力。

💻 代码实践:Scheduler

fromtransformersimportget_cosine_schedule_with_warmup# 假设总步数为 1000,预热步数为 100optimizer=optim.AdamW(model.parameters(),lr=1e-3)scheduler=get_cosine_schedule_with_warmup(optimizer,num_warmup_steps=100,num_training_steps=1000)# 在训练循环中# optimizer.step()# scheduler.step()

3. 归一化技术 (Normalization)

归一化旨在解决Internal Covariate Shift (ICS)问题,使各层输入的分布保持稳定,从而加速收敛并允许使用更大的学习率。

通俗解释:Internal Covariate Shift
把深度网络看作一个“流水线工厂”。第 2 层(工人B)习惯处理第 1 层(工人A)传过来的“标准件”。
但随着训练进行,工人 A 的参数在变,传给 B 的产品特性(数据分布)也在不停地变。工人 B 就不得不一直重新适应 A 的变化,导致整个工厂效率极低。
归一化(BN/LN)就像在 A 和 B 之间放了一个“质检员”,不管 A 产出什么,都强行把它标准化(均值0方差1)再给 B。这样 B 就能稳定工作了。

方法适用场景维度 (Input: N, C, H, W)描述
Batch Norm (BN)CNN (CV任务)对 N, H, W 归一化依赖 Batch Size,训练/推理行为不同
Layer Norm (LN)RNN/Transformer (NLP)对 C, H, W 归一化独立于 Batch Size,对序列长度不敏感
RMSNormLLM (如 LLaMA)同 LNLN 的简化版,去除了均值中心化,仅保留缩放,计算更高效

RMSNorm 公式

相比 LayerNorm,RMSNorm (Root Mean Square Layer Normalization) 省略了减去均值的步骤:
x ˉ i = x i RMS ( x ) g i , where RMS ( x ) = 1 n ∑ j = 1 n x j 2 + ϵ \bar{x}_i = \frac{x_i}{\text{RMS}(x)} g_i, \quad \text{where } \text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{j=1}^n x_j^2 + \epsilon}xˉi=RMS(x)xigi,whereRMS(x)=n1j=1nxj2+ϵ


4. 正则化 (Regularization)

为了防止模型过拟合(在训练集表现好,测试集表现差),需要引入正则化。

4.1 Dropout

在训练过程中,随机将一部分神经元的输出置为 0。这相当于训练了无数个子网络的集成,迫使网络不过分依赖某些特定的特征。

类比理解
Dropout 的机制与随机森林 (Random Forest)非常相似。

  • 训练时:每次随机关掉一部分神经元,相当于每次都在训练一个不同的“残缺版”子网络。
  • 推理时:所有神经元全开,相当于把这成百上千个子网络的预测结果做了“加权平均”。
    这种隐式的集成学习 (Ensemble Learning)有效降低了模型的方差,提升了泛化能力。

注:现代大模型训练中,为防止破坏特征,有时会减少Dropout的使用或仅在特定位置使用。

4.2 Weight Decay (L2 正则)

在损失函数中加入权重的平方和惩罚项:L = L d a t a + λ ∣ ∣ w ∣ ∣ 2 L = L_{data} + \lambda ||w||^2L=Ldata+λ∣∣w2。这限制了权重的大小,防止模型过于复杂。

4.3 数据增强 (Data Augmentation)

通过对训练数据进行变换(如图片的翻转、裁剪、颜色变换,文本的掩码、回译等)来增加数据多样性,是提升模型鲁棒性最直接有效的方法。


5. 总结

在深度学习训练中,“炼丹”技巧往往和模型架构一样重要:

  1. 优化器:首选AdamW,它是目前 CV 和 NLP 领域的通用选择。
  2. 学习率:配合Warmup + Cosine Decay策略,能显著提升收敛效果。
  3. 归一化:CNN 用 BN,Transformer/RNN 用 LN 或RMSNorm
  4. 正则化:合理使用 Weight Decay 和 Dropout 防止过拟合。

掌握这些组件的原理与搭配,是训练高性能模型的基础。


参考资料

  • Decoupled Weight Decay Regularization (AdamW Paper)
  • Root Mean Square Layer Normalization
  • PyTorch Optimization Documentation
http://www.jsqmd.com/news/79662/

相关文章:

  • Java毕设项目:基于springboot成都旅游网四季成都、特色文化(源码+文档,讲解、调试运行,定制等)
  • League Akari:6个实用功能让你告别繁琐操作,轻松上分
  • api vs jsp 绑定风格
  • 理解 Proxy 原理及如何拦截 Map、Set 等集合方法调用实现自定义拦截和日志——含示例代码解析
  • Java毕设项目:基于springboot厨具厂产品在线销售系统设计与实现小程序(源码+文档,讲解、调试运行,定制等)
  • Java毕设项目:基于springboot二手商品网站(源码+文档,讲解、调试运行,定制等)
  • 详解 Gitee/GitHub 中 HTTPS/SSH 方式数据库仓库创建与本地连接
  • 第五十七篇-ComfyUI+V100-32G+安装SD1.5
  • 突破实时视频生成瓶颈:Krea Realtime 14B模型革新文本到视频技术
  • systemd-resolved.service实验实战3
  • 哔哩下载姬:5个实用技巧让你的B站视频下载效率翻倍
  • Windows右键菜单终极优化指南:从卡顿到流畅的深度解析
  • 腾讯优图实验室开源Youtu-Embedding文本表示模型,赋能企业级AI应用创新
  • SAM3在医疗影像里“指鹿为马”?MedSAM3来了——文本一句话,精准分割病灶
  • Java毕设项目:基于SpringBoot网上超市的设计与实现基于springboot超市在线销售系统的设计与实现(源码+文档,讲解、调试运行,定制等)
  • 小学娃近视防控不费妈!这款眼调节训练灯,学习护眼一步到位
  • 无人机看地面小目标总“眼瞎”?MambaRefine-YOLO来救场:双模态融合+高效检测,精度直接拉满!
  • QDialog-基础讲解
  • 【异常】豆包TTS语音合成常见报错及SSML代码实现解决方案
  • Java 大视界 -- Java 大数据在智能教育学习成果评估体系完善与教育质量提升中的深度应用(434)
  • 【项目实战】Vercel 是一个让你的网站“瞬间上线”的云平台。Vercel 现在确实是技术圈的“当红炸子鸡”,尤其是在个人博客和前端开发领域。
  • 【异常】Coze提示WorkflowEventError(errorCode=5000, errorMessage=The request parameter is illegal, see:
  • Python-2. Python语言初识-教学设计
  • IC卡门禁读卡器是一款高性能、多协议兼容的智能识别终端,专为门禁、梯控、闸机等场景设计。它同时支持125KHz低频协议和13.56MHz高频协议,具备极强的环境适应性,可在金属表面(建议开孔安装)
  • 02、打不开某个网站
  • 基于SpringBoot + Vue的企业培训与绩效评估系统
  • 为什么近视的孩子更推荐眼调节训练灯?不是护眼灯不好,而是需求不一样!
  • 基于SpringBoot + Vue的健身房管理系统
  • 每个神经元负责提取不同特征?还是每层神经元负责提取不同特征?
  • WPS Office镜像大全