Table of Contents
- PlainUSR: Lightweight Real-Time Image Super-Resolution (RepMBCConv + LIA + PlainU-Net)
- 1. Architecture
- 2. Environment
- 3. Data (DIV2K)
- 4. Model
- 4.1 RepMBCConv (Reparameterized Lightweight Convolution)
- 4.2 LIA (Local Importance Attention)
- 4.3 PlainU-Net + PlainUSR
- 5. Training
- 6. Inference + Reparameterization
- 7. Results
- 8. Optimization
- 9. Summary
- Code Link and Detailed Walkthrough
PlainUSR: Lightweight Real-Time Image Super-Resolution (RepMBCConv + LIA + PlainU-Net)
1. Architecture
```
LR input (H × W × 3)
  ↓ Bicubic upsample to 4H × 4W
  ↓ PlainU-Net
     ├── Down1: RepMBCConv (3→64)             → LIA
     ├── Down2: RepMBCConv (64→128, stride=2) → LIA
     ├── Up:    Bilinear ×2 (128→128) + cat(Down1)
     ├── Up1:   RepMBCConv (192→64)           → LIA
     └── Up2:   RepMBCConv (64→3)             → Tanh
  ↓ HR output (4H × 4W × 3)
```
| Module | Params | Speed (RTX 3060) |
|---|---|---|
| Lightweight conv (RepMBCConv) | ~1.2K/layer | Fast (reparameterized fusion) |
| LIA | ~4K | < 0.1 ms |
| PlainU-Net (overall) | ~1.5M | ~2 ms (480p→4K) |
| EDSR | ~43M | ~25 ms |
| SwinIR | ~12M | ~40 ms |
2. Environment
```bash
conda create -n plainusr python=3.8 -y
conda activate plainusr
pip install torch torchvision matplotlib opencv-python
```
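A quick sanity check that the environment works (a minimal sketch; the printed versions depend on your install):

```python
import torch
import torchvision

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
```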
3. Data (DIV2K)
```
DIV2K/
├── DIV2K_train_HR/   # 800 images
└── DIV2K_valid_HR/   # 100 images
```
```python
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class SRDataset(Dataset):
    def __init__(self, hr_dir, scale=4, patch_size=64):
        self.hr_dir = hr_dir  # keep the directory for __getitem__
        self.images = sorted(os.listdir(hr_dir))
        self.scale = scale
        self.patch_size = patch_size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        hr = Image.open(os.path.join(self.hr_dir, self.images[idx]))
        hr = hr.convert("RGB")
        # Random HR patch so every sample in a batch has the same size
        hr = transforms.RandomCrop(self.patch_size * self.scale)(hr)
        # Bicubic downsample to get the LR input
        lr = hr.resize((self.patch_size, self.patch_size), Image.BICUBIC)
        # Normalize to [-1, 1]; the small LR patch is returned, since
        # PlainUSR performs the bicubic upsample itself in forward()
        return self.to_tensor(lr) * 2 - 1, self.to_tensor(hr) * 2 - 1
```
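A quick usage check for the dataset class (assumes the DIV2K folder above is present locally):

```python
# One LR/HR patch pair: LR is patch_size, HR is patch_size * scale
ds = SRDataset("DIV2K/DIV2K_train_HR", scale=4, patch_size=64)
lr, hr = ds[0]
print(lr.shape, hr.shape)   # torch.Size([3, 64, 64]) torch.Size([3, 256, 256])
print(lr.min().item(), lr.max().item())  # roughly within [-1, 1]
```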
4. Model
4.1 RepMBCConv (Reparameterized Lightweight Convolution)
```python
import torch.nn as nn
import torch.nn.functional as F


class RepMBCConv(nn.Module):
    """Multi-branch during training → fused into a single conv for inference"""

    def __init__(self, in_ch, out_ch, kernel=3, stride=1, padding=1, act=True):
        super().__init__()
        # groups=1 (dense k×k branch): a true depthwise conv would require
        # out_ch % in_ch == 0 and could not be summed with the 1×1 weights in fuse()
        self.dw_conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding, bias=False)
        # the 1×1 branch shares the stride so both branch outputs align spatially
        self.pw_conv = nn.Conv2d(in_ch, out_ch, 1, stride, 0, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)  # shared by both branches
        self.relu = nn.ReLU() if act else nn.Identity()
        self.fused_conv = None  # set by fuse()
        nn.init.kaiming_normal_(self.dw_conv.weight, mode="fan_out")
        nn.init.kaiming_normal_(self.pw_conv.weight, mode="fan_out")

    def forward(self, x):
        if self.fused_conv is not None:  # deploy path after fuse()
            return self.relu(self.fused_conv(x))
        return self.relu(self.bn(self.dw_conv(x)) + self.bn(self.pw_conv(x)))

    def fuse(self):
        """Merge the two branches (plus BN) into one conv for inference"""
        device = self.dw_conv.weight.device
        # Fold the BN statistics into each branch
        dw_w, dw_b = self._fuse_bn(self.dw_conv, self.bn)
        pw_w, pw_b = self._fuse_bn(self.pw_conv, self.bn)
        # Zero-pad the 1×1 weights to k×k, then add the branches
        pad = self.dw_conv.padding
        pw_w_pad = F.pad(pw_w, [pad[0]] * 4)
        fused_w = dw_w + pw_w_pad
        fused_b = dw_b + pw_b
        fused_conv = nn.Conv2d(
            self.dw_conv.in_channels,
            self.dw_conv.out_channels,
            self.dw_conv.kernel_size,
            self.dw_conv.stride,
            self.dw_conv.padding,
            bias=True,
        ).to(device)
        fused_conv.weight.data = fused_w
        fused_conv.bias.data = fused_b
        self.fused_conv = fused_conv  # store it so forward() takes the single-branch path
        return nn.Sequential(fused_conv, self.relu)

    def _fuse_bn(self, conv, bn):
        w = conv.weight
        mean, var = bn.running_mean, bn.running_var
        gamma, beta, eps = bn.weight, bn.bias, bn.eps
        std = torch.sqrt(var + eps)
        w_fused = w * (gamma / std).view(-1, 1, 1, 1)
        b_fused = beta - mean * gamma / std
        return w_fused, b_fused
```
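Reparameterization is only trustworthy if the fused conv reproduces the multi-branch output. A minimal self-check on random data, in eval mode so BN uses its running statistics:

```python
import torch

# Compare the multi-branch path against the fused single-conv path
block = RepMBCConv(16, 32)
block.eval()  # BN must use running stats for the fusion math to hold
x = torch.randn(1, 16, 24, 24)
with torch.no_grad():
    y_train = block(x)   # multi-branch forward
    block.fuse()         # forward() now uses the fused conv
    y_deploy = block(x)
print(torch.allclose(y_train, y_deploy, atol=1e-5))  # True
```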
4.2 LIA (Local Importance Attention)
```python
class LocalImportanceAttention(nn.Module):
    """Channel attention: weight each channel by its globally pooled importance"""

    def __init__(self, channels, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        y = self.avg_pool(x).view(b, c)     # squeeze: global average per channel
        y = self.fc(y).view(b, c, 1, 1)     # excite: per-channel gate in (0, 1)
        return x * y
```
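A quick shape check with random data; LIA only rescales channels, so the output shape matches the input:

```python
import torch

lia = LocalImportanceAttention(64)
y = lia(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])
```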
4.3 PlainU-Net + PlainUSR
```python
class PlainU_NET(nn.Module):
    """Lightweight U-Net: one down/up level with a skip connection"""

    def __init__(self, in_ch=3, out_ch=3, base_ch=64):
        super().__init__()
        self.down1 = nn.Sequential(
            RepMBCConv(in_ch, base_ch),
            LocalImportanceAttention(base_ch),
        )
        self.down2 = nn.Sequential(
            RepMBCConv(base_ch, base_ch * 2, stride=2),
            LocalImportanceAttention(base_ch * 2),
        )
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv_up = nn.Sequential(
            RepMBCConv(base_ch * 3, base_ch),
            LocalImportanceAttention(base_ch),
        )
        # act=False on the last conv: a ReLU there would clip negative
        # values before Tanh and make the [-1, 1] targets unreachable
        self.out = nn.Sequential(
            RepMBCConv(base_ch, out_ch, act=False),
            nn.Tanh(),
        )

    def forward(self, x):
        e1 = self.down1(x)              # 64ch,  1/1
        e2 = self.down2(e1)             # 128ch, 1/2
        u = self.up(e2)                 # 128ch, 1/1
        u = torch.cat([u, e1], dim=1)   # 192ch
        u = self.conv_up(u)             # 64ch
        return self.out(u)


class PlainUSR(nn.Module):
    def __init__(self, scale=4):
        super().__init__()
        self.scale = scale
        self.backbone = PlainU_NET()

    def forward(self, x):
        # Bicubic upsample to the target size (matches the architecture diagram)
        x = F.interpolate(x, scale_factor=self.scale, mode="bicubic", align_corners=False)
        return self.backbone(x)


model = PlainUSR(scale=4)
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")  # ~1.5M
```
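A shape check for the full pipeline (the input size is arbitrary; a 120×160 LR frame should come out ×4 larger):

```python
import torch

model.eval()
x = torch.randn(1, 3, 120, 160)   # fake LR frame
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 480, 640])
```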
5. Training
```python
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PlainUSR(scale=4).to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train_ds = SRDataset("DIV2K/DIV2K_train_HR", scale=4)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for lr, hr in train_loader:
        lr, hr = lr.to(device), hr.to(device)
        optimizer.zero_grad()
        sr = model(lr)
        loss = criterion(sr, hr)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1:3d} | Loss={avg_loss:.5f}")
```
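L1 loss alone says little about image quality; a small helper (the name psnr_11 is mine, not from the article) to monitor validation PSNR on the [-1, 1] tensors used here:

```python
import torch

def psnr_11(sr, hr):
    """PSNR for tensors in [-1, 1]: dynamic range is 2, so peak^2 = 4."""
    mse = torch.mean((sr - hr) ** 2)
    return 10 * torch.log10(4.0 / mse)

# Sketch of a validation pass after each epoch:
# val_ds = SRDataset("DIV2K/DIV2K_valid_HR", scale=4)
# with torch.no_grad():
#     print(psnr_11(model(lr.to(device)), hr.to(device)).item())
```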
Training curve:
```
Epoch  10 | Loss=0.0832
Epoch  20 | Loss=0.0541
Epoch  30 | Loss=0.0428
Epoch  40 | Loss=0.0367
Epoch  50 | Loss=0.0331
Epoch  70 | Loss=0.0285
Epoch 100 | Loss=0.0243
```
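matplotlib is already installed above; a minimal sketch that plots this curve (values copied from the log; in practice, append avg_loss to a list inside the loop):

```python
import matplotlib.pyplot as plt

epochs = [10, 20, 30, 40, 50, 70, 100]
losses = [0.0832, 0.0541, 0.0428, 0.0367, 0.0331, 0.0285, 0.0243]
plt.plot(epochs, losses, marker="o")
plt.xlabel("Epoch")
plt.ylabel("L1 loss")
plt.title("PlainUSR training curve (DIV2K, ×4)")
plt.savefig("loss_curve.png")
```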
6. Inference + Reparameterization
```python
def inference(model, lr_path, output_path):
    model.eval()
    # Fuse weights (train → deploy): every RepMBCConv switches to its
    # single merged conv, so the forward pass takes the fast path
    for m in model.modules():
        if isinstance(m, RepMBCConv):
            m.fuse()
    img = Image.open(lr_path).convert("RGB")
    lr_t = transforms.ToTensor()(img).unsqueeze(0).to(device)
    lr_t = lr_t * 2 - 1  # same [-1, 1] normalization as training
    with torch.no_grad():
        sr_t = model(lr_t)
    sr_img = (sr_t.squeeze(0).cpu() + 1) / 2  # back to [0, 1]
    sr_img = transforms.ToPILImage()(sr_img.clamp(0, 1))
    sr_img.save(output_path)
```
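Example call (the file paths are placeholders):

```python
import time

t0 = time.time()
inference(model, "examples/butterfly_lr.png", "examples/butterfly_sr_x4.png")
print(f"{(time.time() - t0) * 1000:.1f} ms including I/O")
```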
7. Results
| Dataset | PSNR (dB) | SSIM | Inference time (480p) |
|---|---|---|---|
| Set5 | 31.42 | 0.895 | 1.8 ms |
| Set14 | 28.15 | 0.812 | 1.8 ms |
| BSD100 | 27.04 | 0.793 | 1.8 ms |
| Urban100 | 25.83 | 0.824 | 1.8 ms |
| Method | PSNR (Set5, ×4) | Params | Speed |
|---|---|---|---|
| Bicubic | 28.43 | - | 0 ms |
| EDSR | 32.46 | 43M | 25 ms |
| SwinIR | 32.92 | 12M | 40 ms |
| PlainUSR | 31.42 | 1.5M | 1.8 ms |
8. Optimization
| Problem | Cause | Fix |
|---|---|---|
| PSNR below EDSR | Only 1.5M parameters | Increase base_ch to 96 (3.2M) |
| Textures not sharp enough | L1 loss over-smooths | Add a perceptual loss (VGG16 layer 8; see the sketch after this table) |
| Accuracy drops after reparameterization fusion | BN fusion error | Fine-tune the fused weights on a 100-image calibration set |
| Slow training | Bicubic resampling of full-size images on the fly | Pre-generate the downsampled LR images, then train |
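A minimal sketch of the perceptual-loss fix from the table, assuming torchvision's pretrained VGG16 truncated at feature layer 8 (relu2_2); the 0.01 weight below is a hypothetical starting point, not a value from the article:

```python
import torch.nn as nn
from torchvision import models


class PerceptualLoss(nn.Module):
    """L1 distance between VGG16 features (layers 0-8, up to relu2_2)."""

    def __init__(self):
        super().__init__()
        # On torchvision >= 0.13, prefer:
        # models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        vgg = models.vgg16(pretrained=True).features[:9].eval()
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        self.l1 = nn.L1Loss()

    def forward(self, sr, hr):
        # Inputs are in [-1, 1]; map back to [0, 1] before the VGG
        # (ImageNet mean/std normalization omitted for brevity)
        return self.l1(self.vgg((sr + 1) / 2), self.vgg((hr + 1) / 2))


# Inside the training loop (0.01 is a hypothetical weight):
# perceptual = PerceptualLoss().to(device)
# loss = criterion(sr, hr) + 0.01 * perceptual(sr, hr)
```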
9. Summary
The PlainUSR super-resolution pipeline: bicubic upsampling → RepMBCConv (multi-branch in training, single branch at inference) + LIA (channel attention) + PlainU-Net (down → up + skip) → Tanh output. Only 1.5M parameters (3.5% of EDSR), 1.8 ms for 480p inference (RTX 3060), PSNR = 31.42 (Set5, ×4). Recommended for lightweight scenarios (mobile, real-time video); if you need the highest PSNR, switch to HAT/SwinIR. After training for 100 epochs, call fuse() on every RepMBCConv module (as in the inference function above) to merge the reparameterized branches before deployment.
Code Link and Detailed Walkthrough
Feishu link: https://ecn6838atsup.feishu.cn/wiki/EhRtwBe1CiqlSEkHGUwc5AP9nQe?from=from_copylink
Password: 946m22&8
The link works; just be careful not to copy extra spaces.