PyTorch 训练稳定性:梯度爆炸前通常有征兆
PyTorch 训练稳定性:梯度爆炸前通常有征兆
一、训练崩掉不是突然发生的
深度学习训练中,loss 变成 NaN、梯度爆炸、显存异常和指标剧烈震荡,看起来像突然发生。实际上,在崩掉之前通常有征兆:梯度范数上升、学习率过高、激活值异常、数据 batch 分布突变、混合精度 loss scale 不稳定。训练稳定性要靠提前监控。
如果只盯最终 loss,就像只看天气结果不看云层变化。模型训练里的“天象”是梯度、权重、激活和数据分布。记录这些信号,才能在爆炸前收手。
二、监控链路:数据、前向、反向一起看
flowchart TD A[训练数据] --> B[Forward] B --> C[Loss] C --> D[Backward] D --> E[梯度范数] E --> F[Optimizer Step] F --> G[稳定性监控]训练稳定性至少监控 loss、learning rate、gradient norm、参数范数、NaN 数量和 batch 数据统计。对于混合精度训练,还要记录 loss scale 和 overflow 次数。只记录 loss 曲线,信息太少。
数据问题也很常见。某些 batch 标签异常、输入过长、空样本、极端值,都可能造成训练抖动。训练崩掉时,不要只怪优化器,先把出问题的 batch 保存下来。
三、代码示例:记录梯度范数
下面是一个简化的 PyTorch 梯度范数记录。
import torch def grad_norm(model): total = 0.0 for p in model.parameters(): if p.grad is not None: param_norm = p.grad.detach().data.norm(2) total += param_norm.item() ** 2 return total ** 0.5 loss.backward() norm = grad_norm(model) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step()梯度裁剪能缓解爆炸,但不是万能药。如果每一步都在强行裁剪,说明学习率、模型结构或数据可能有问题。裁剪是安全带,不是发动机。
记录时要结合 step、样本 ID 和学习率。否则看到梯度异常,也不知道是哪批数据触发的。可复现的异常才好修。
四、排查顺序:先数据,再学习率,再结构
遇到 NaN,先检查数据是否有 NaN、Inf、空文本或异常标签。再检查学习率是否过高、warmup 是否太短、混合精度是否 overflow。最后再怀疑模型结构。顺序很重要,别一上来重写网络。
如果使用分布式训练,要记录每个 rank 的异常。某个 rank 数据异常,可能导致全局训练失败。不要只看 rank0 日志。训练集群里,最安静的错误往往藏在非主进程。
最后,保存崩溃前 checkpoint 和 batch。这样可以在小环境复现,而不是重新跑几小时等它再次爆炸。训练稳定性工程,靠的是证据留存。
还要关注优化器状态。只保存模型参数,有时无法复现继续训练后的行为,因为 Adam 的动量状态、学习率调度器状态和 AMP scaler 都会影响下一步。稳定性问题发生时,完整 checkpoint 比单独权重更有价值。
如果训练经常在同一阶段崩溃,可以把那一段单独缩小复现。用更小数据、更短 step 和固定 seed 重放,排查速度会快很多。不要每次都从头跑完整训练等故障出现。
五、总结
PyTorch 训练稳定性要提前监控梯度、参数、loss scale 和数据 batch。梯度爆炸前通常有征兆,NaN 也常有来源。先查数据,再查学习率和混合精度,最后再动结构。炼丹也要看仪表盘。
