保姆级教程:用DF2K和OST数据集复现Real-ESRGAN训练全流程(附超参数避坑点)
从零实现Real-ESRGAN:DF2K+OST数据集训练全流程与调参实战
当我在实验室第一次尝试复现Real-ESRGAN论文时,面对官方文档中简略的训练说明和庞杂的数据集要求,整整两周时间都在解决各种环境配置和参数调试问题。本文将分享从数据准备到模型收敛的完整实操经验,特别是那些官方文档没有明确说明的"坑点"。
1. 环境配置与数据准备
1.1 硬件与基础环境
建议使用至少24GB显存的NVIDIA GPU(如3090或4090),因为Real-ESRGAN的GAN阶段训练对显存需求较高。以下是基础环境配置步骤:
conda create -n realesrgan python=3.8 conda activate realesrgan pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html git clone https://github.com/xinntao/Real-ESRGAN.git cd Real-ESRGAN pip install -r requirements.txt注意:PyTorch版本过高可能导致某些自定义算子不兼容,建议严格使用1.10.x版本
1.2 数据集获取与整理
DF2K数据集实际包含两个子集:
- DIV2K:800张2K分辨率图像
- Flickr2K:2650张Flickr上的高质图片
OST数据集则包含10324张按场景分类的图像(天空、水体、建筑等)。下载后建议按以下结构组织:
datasets/ ├── DF2K/ │ ├── DIV2K_train_HR/ │ └── Flickr2K_HR/ └── OST/ ├── sky/ ├── water/ └── ...使用以下命令快速校验数据完整性:
import os print(f"DIV2K图像数量: {len(os.listdir('datasets/DF2K/DIV2K_train_HR'))}") # 应输出800 print(f"Flickr2K图像数量: {len(os.listdir('datasets/DF2K/Flickr2K_HR'))}") # 应输出26502. 数据预处理关键步骤
2.1 图像裁剪策略
原始论文使用随机裁剪512x512 patches的训练方式。实际操作中发现两个优化点:
- 对纹理丰富的区域(如建筑立面)适当增大裁剪尺寸
- 对平滑区域(如天空)可减小尺寸以增加batch size
改进后的裁剪脚本:
def adaptive_crop(img, min_size=512, max_size=768): std = img.std() # 计算局部方差 crop_size = min(max_size, max(min_size, int(512 + (std/50)*256))) return random_crop(img, crop_size)2.2 退化模型参数调整
Real-ESRGAN采用二阶退化过程:
- 模糊+噪声+JPEG压缩
- 分辨率降低+传感器噪声
建议修改options/train_realesrgan.yml中的退化参数:
degradation_params: blur_kernel_size: [7, 9, 11] # 原为[21,23,25] jpeg_quality: [30, 50] # 原为[30,60]实验表明:适度减小模糊核尺寸能更好平衡真实感和锐度
3. 两阶段训练实战细节
3.1 PSNR导向阶段(L1 Loss)
关键配置参数对照表:
| 参数 | 论文推荐值 | 实际有效值 | 说明 |
|---|---|---|---|
| batch_size | 16 | 12-14 | 受显存限制可调小 |
| lr | 2e-4 | 1.5e-4 | 初始学习率 |
| lr_decay | - | [300k,600k] | 论文未明确说明 |
训练启动命令:
python train.py -opt options/train_realesrgan.yml \ --auto_resume \ --debug常见问题解决:
- Loss震荡大:尝试减小batch size或增加gradient_accumulation_steps
- 显存不足:设置
--num_workers 2减少数据加载压力
3.2 GAN训练阶段调参技巧
生成器与判别器的学习率比例是成功关键。经过多次实验验证的最佳配置:
# 修改train_realesrgan.yml g_optim: lr: !!float 1e-4 d_optim: lr: !!float 5e-4 # 原始值为1e-4 loss: pixel_weight: 1.0 perceptual_weight: 1.0 gan_weight: 0.05 # 原始值为0.1训练中建议监控以下指标:
- 生成器loss稳定下降
- 判别器accuracy维持在0.6-0.7
- 验证集PSNR波动范围<0.5dB
4. 模型微调与效果提升
4.1 自适应损失权重调整
在训练后期(约300k迭代后),可动态调整损失权重:
def adjust_loss_weights(iter): if iter < 100000: return 1.0, 1.0, 0.01 # L1, percep, GAN elif iter < 300000: return 0.8, 1.2, 0.05 else: return 0.5, 1.5, 0.14.2 模型融合策略
将不同checkpoint的模型进行加权融合,能显著提升最终效果:
def model_fusion(checkpoints, weights): net = RealESRGAN() params = [] for ckpt in checkpoints: net.load_state_dict(torch.load(ckpt)) params.append([p.data.clone() for p in net.parameters()]) # 加权平均 for i, p in enumerate(net.parameters()): p.data = sum(w*param[i] for w, param in zip(weights, params)) return net典型组合方案:
- 300k迭代模型(权重0.3)
- 400k迭代模型(权重0.4)
- 最终模型(权重0.3)
5. 效果评估与常见问题
5.1 定量评估指标对比
在DIV2K验证集上的典型结果:
| 模型阶段 | PSNR(dB) | SSIM | LPIPS |
|---|---|---|---|
| PSNR阶段 | 28.7 | 0.820 | 0.312 |
| GAN阶段 | 27.9 | 0.805 | 0.195 |
LPIPS值降低表明感知质量提升,尽管PSNR略有下降
5.2 典型问题排查指南
问题1:训练后期出现伪影
- 检查数据增强中的退化参数
- 降低GAN loss权重
- 尝试添加gradient penalty
问题2:细节过度平滑
- 减小模糊核尺寸
- 增加perceptual loss权重
- 在数据预处理中保留更多高频信息
问题3:训练不稳定
# 添加梯度裁剪 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)在多次复现过程中,最耗时的往往不是训练本身,而是数据预处理和参数调试阶段。建议在正式开始大规模训练前,先用小规模数据(如100张图)快速验证整个pipeline的正确性。
