SRCNN超分辨率实战:在Colab上用PyTorch训练自己的图像修复模型(附数据集处理技巧)
SRCNN超分辨率实战:在Colab上用PyTorch训练自己的图像修复模型
当你在社交媒体上看到一张模糊的老照片,或是从监控视频中截取的关键帧分辨率太低时,是否想过用AI技术让它们重获新生?超分辨率技术正是为解决这类问题而生。本文将带你从零开始,在Google Colab的免费GPU环境下,用PyTorch实现经典的SRCNN模型,并处理你自己的图片数据集。
1. 环境准备与数据管理
在Colab中运行深度学习项目,首先要解决的是数据存储问题。与本地开发不同,Colab的临时存储空间会在会话结束后清空,因此我们需要合理利用Google Drive进行持久化存储。
from google.colab import drive drive.mount('/content/drive')挂载成功后,建议在Drive中创建如下目录结构:
SRCNN_Project/ ├── data/ │ ├── raw/ # 存放原始图像 │ ├── processed/ # 存放处理后的h5文件 ├── outputs/ # 训练输出 ├── logs/ # TensorBoard日志对于数据集选择,除了论文中提到的91-image和Set5/Set14,我们还可以使用以下更适合初学者的替代方案:
- DIV2K:包含800张训练图像和100张验证图像
- BSD500:伯克利分割数据集,含500张自然图像
- Flickr2K:2650张高分辨率图像
提示:使用小规模数据集时,建议将图像裁剪为256x256或128x128的patch,这样可以增加样本数量并减少显存消耗。
2. 高效数据预处理技巧
原始论文要求将图像转换为h5格式,这对Colab环境尤为重要——频繁读取小文件会显著降低IO性能。我们改进的prepare.py脚本增加了以下功能:
def create_h5_file(image_paths, output_path, patch_size=33, stride=14, scale=3): h5_file = h5py.File(output_path, 'w') lr_patches = [] hr_patches = [] for image_path in tqdm(image_paths): hr = cv2.imread(image_path, cv2.IMREAD_COLOR) hr = cv2.cvtColor(hr, cv2.COLOR_BGR2RGB) lr = cv2.resize(hr, (hr.shape[1]//scale, hr.shape[0]//scale), interpolation=cv2.INTER_CUBIC) # 生成patch for i in range(0, hr.shape[0]-patch_size+1, stride): for j in range(0, hr.shape[1]-patch_size+1, stride): hr_patch = hr[i:i+patch_size, j:j+patch_size] lr_patch = lr[i//scale:(i+patch_size)//scale, j//scale:(j+patch_size)//scale] lr_patch = cv2.resize(lr_patch, (patch_size, patch_size), interpolation=cv2.INTER_CUBIC) lr_patches.append(lr_patch.transpose(2,0,1)) hr_patches.append(hr_patch.transpose(2,0,1)) # 转换为numpy数组并保存 h5_file.create_dataset('lr', data=np.array(lr_patches, dtype=np.float32)/255.) h5_file.create_dataset('hr', data=np.array(hr_patches, dtype=np.float32)/255.) h5_file.close()关键参数说明:
| 参数 | 推荐值 | 作用 |
|---|---|---|
| patch_size | 33 | 训练patch的大小 |
| stride | 14 | 滑动窗口步长 |
| scale | 3 | 超分辨率放大倍数 |
3. 模型训练与优化
SRCNN的PyTorch实现虽然简单,但在Colab环境中训练时仍有多个优化点需要注意:
class SRCNN(nn.Module): def __init__(self): super(SRCNN, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4) self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0) self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2) self.relu = nn.ReLU(inplace=True) def forward(self, x): x = self.relu(self.conv1(x)) x = self.relu(self.conv2(x)) x = self.conv3(x) return x训练时的实用技巧:
学习率策略:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='max', factor=0.5, patience=5, verbose=True)混合精度训练(减少显存占用):
scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): preds = model(inputs) loss = criterion(preds, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()早停机制:
if epoch - best_epoch > 20: # 连续20轮未提升 print("Early stopping triggered") break
4. 自定义图像处理实战
训练完成后,我们需要一个灵活的测试脚本处理各种来源的图像:
def process_custom_image(model, image_path, scale=3, device='cuda'): # 支持多种图像格式 img = Image.open(image_path).convert('RGB') original_size = img.size # 调整尺寸为scale的整数倍 new_width = (img.width // scale) * scale new_height = (img.height // scale) * scale if new_width != img.width or new_height != img.height: img = img.resize((new_width, new_height), Image.BICUBIC) # 生成低分辨率版本 lr = img.resize((new_width//scale, new_height//scale), Image.BICUBIC) lr = lr.resize((new_width, new_height), Image.BICUBIC) # 上采样 # 转换到YCbCr色彩空间 ycbcr = lr.convert('YCbCr') y, cb, cr = ycbcr.split() # 处理Y通道 y_tensor = torch.from_numpy(np.array(y, dtype=np.float32)/255.) y_tensor = y_tensor.unsqueeze(0).unsqueeze(0).to(device) with torch.no_grad(): pred_y = model(y_tensor).clamp(0, 1) # 合并通道 pred_y = pred_y[0,0].cpu().numpy() * 255. pred_y = Image.fromarray(pred_y.astype(np.uint8), mode='L') cb = cb.resize(pred_y.size, Image.BICUBIC) cr = cr.resize(pred_y.size, Image.BICUBIC) result = Image.merge('YCbCr', [pred_y, cb, cr]).convert('RGB') return result.resize(original_size, Image.BICUBIC)常见问题解决方案:
- 边缘伪影:在测试时对图像进行镜像padding
- 色彩失真:确保在YCbCr空间只增强Y通道
- 大图像内存不足:使用滑动窗口分块处理
5. 模型部署与性能提升
虽然SRCNN结构简单,但我们仍可以通过以下方式提升其实用性:
模型量化(减小模型体积):
quantized_model = torch.quantization.quantize_dynamic( model, {torch.nn.Conv2d}, dtype=torch.qint8) torch.jit.save(torch.jit.script(quantized_model), 'srcnn_quantized.pt')ONNX导出(跨平台部署):
dummy_input = torch.randn(1, 3, 256, 256).to(device) torch.onnx.export(model, dummy_input, "srcnn.onnx", input_names=["input"], output_names=["output"], dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})对于希望进一步提升效果的用户,可以考虑这些改进方向:
- 使用ESRGAN的感知损失替代MSE
- 添加通道注意力机制
- 采用渐进式超分辨率策略
