067、高效训练技巧:梯度检查点、混合精度与分布式
昨天深夜调一个ViT-Base的扩散模型训练,显存直接爆到16G顶满,batch_size降到4都跑不动。盯着nvidia-smi里那刺眼的“Out of Memory”,我意识到又到了在速度与内存之间走钢丝的时候。这年头训扩散模型,没点压榨硬件的本事,连实验都跑不起来。
梯度检查点:用时间换空间的老把戏
很多新人看到显存不够第一反应是减batch_size,但有些任务batch_size太小收敛都成问题。这时候该祭出梯度检查点(Gradient Checkpointing)了。
# 常规训练时,前向传播的中间激活值全存着,等着反向传播用# 显存占用随网络深度线性增长,太奢侈了# 开启检查点后,只保存部分层的激活,其他的需要时重新算fromtorch.utils.checkpointimportcheckpointdefforward_with_checkpoint