Pytorch图像去噪实战(二十七):EMA指数滑动平均实战,让图像去噪模型推理结果更稳定
Pytorch图像去噪实战(二十七):EMA指数滑动平均实战,让图像去噪模型推理结果更稳定
一、问题场景:训练后期loss波动,保存哪个模型都不放心
训练图像去噪模型时,经常会遇到这种情况:
- epoch 60 效果不错
- epoch 70 loss更低,但图像更糊
- epoch 80 指标波动
- epoch 90 局部伪影变多
我一开始的做法是保存多个 checkpoint,然后人工挑选。
但这样很麻烦,也不稳定。
后来我在训练中加入 EMA,也就是指数滑动平均权重,推理效果通常更稳定。
二、什么是EMA?
EMA 全称 Exponential Moving Average。
它维护一份模型权重的平滑版本:
ema_weight = decay * ema_weight + (1 - decay) * current_weight可以理解为:
不完全相信当前一步的权重,而是保留一份长期平均后的权重。
三、为什么EMA适合图像去噪?
图像去噪对模型权重波动很敏感。
一点点变化可能导致:
<