PyTorch自编码器训练崩溃怎么办?教你一招避坑
💓 博客主页:瑕疵的CSDN主页
📝 Gitee主页:瑕疵的gitee主页
⏩ 文章专栏:《热点资讯》
被PyTorch自编码器OOM坑到想删库跑路,终于找到根治方案
目录
昨晚写自编码器,跑着跑着突然报RuntimeError: CUDA out of memory。气得我直接把键盘砸在桌上。
数据是128x128的灰度图,batch size设64,跑两轮就崩了。
核心根源:GPU显存不够。自编码器处理图像时,batch size每大一倍,显存直接翻倍。我一开始以为是模型太深,结果发现是batch size设高了。更坑的是,没清理缓存,显存越积越多。
错误示范:
batch_size=64# 太大!直接爆显存train_loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)model=AutoEncoder().to('cuda')# 模型已加载到GPUoptimizer=Adam(model.parameters(),lr=0.001)forepochinrange(100):fordataintrain_loader:data=data.to('cuda')# 数据送GPUoutputs=model(data)# 前向传播loss=criterion(outputs,data)# 计算损失loss.backward()# 反向传播optimizer.step()# 更新参数# 没有清理GPU缓存!显存越用越多正确姿势:
batch_size=16# 从64砍到16,显存直降50%train_loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)model=AutoEncoder().to('cuda')optimizer=Adam(model.parameters(),lr=0.001)forepochinrange(100):fordataintrain_loader:data=data.to('cuda')# 关键:每次迭代清理GPU缓存torch.cuda.empty_cache()outputs=model(data)loss=criterion(outputs,data)loss.backward()optimizer.step()避坑总结:
- batch size别贪大。从8开始试,跑不动再调。
- 用
torch.cuda.memory_summary()实时看显存,别等崩了。 torch.cuda.empty_cache()是临时解,但比直接OOM强。- 模型太复杂?先用小网络跑通流程。
(左边是batch=64,显存爆到10G;右边batch=16,稳定在5G)
我测试过,改完batch size后,训练稳如老狗。
下次再写自编码器,先问自己:这batch size能塞进显存吗?
别等崩了才哭,早调早好。
