当前位置: 首页 > news >正文

PlainUSR:轻量实时图像超分(RepMBCConv + LIA + PlainU-Net)

文章目录

  • PlainUSR:轻量实时图像超分(RepMBCConv + LIA + PlainU-Net)
    • 一、架构
    • 二、环境
    • 三、数据 (DIV2K)
    • 四、模型
      • 4.1 RepMBCConv (重参数化轻量卷积)
      • 4.2 LIA (局部重要性注意力)
      • 4.3 PlainU-Net + PlainUSR
    • 五、训练
      • 训练曲线
    • 六、推理 + 重参数化
    • 七、结果
    • 八、优化
    • 九、总结
    • 代码链接与详细流程

购买即可解锁1000+YOLO优化文章,并且还有海量深度学习复现项目,价格仅需两杯奶茶的钱,每日更新

PlainUSR:轻量实时图像超分(RepMBCConv + LIA + PlainU-Net)

一、架构

LR 输入 (H × W × 3) ↓ Bicubic 上采样至 4H × 4W ↓ PlainU-Net ├── Down1: RepMBCConv (3→64) → LIA ├── Down2: RepMBCConv (64→128, stride=2) → LIA ├── Up: Bilinear (128→128) + cat(Down1) ├── Up1: RepMBCConv (192→64) → LIA └── Up2: RepMBCConv (64→3) → Tanh ↓ HR 输出 (4H × 4W × 3)
模块参数量速度 (RTX 3060)
轻量 Conv (RepMBCConv)~1.2K/层快 (重参数化融合)
LIA~4K< 0.1ms
PlainU-Net 整体~1.5M~2ms (480p→4K)
EDSR~43M~25ms
SwinIR~12M~40ms

二、环境

conda create-nplainusrpython=3.8-yconda activate plainusr pipinstalltorch torchvision matplotlib opencv-python

三、数据 (DIV2K)

DIV2K/ ├── DIV2K_train_HR/ # 800 张 └── DIV2K_valid_HR/ # 100 张
importtorchfromtorch.utils.dataimportDataset,DataLoaderfromtorchvisionimporttransformsfromPILimportImageimportosclassSRDataset(Dataset):def__init__(self,hr_dir,scale=4,patch_size=64):self.images=sorted(os.listdir(hr_dir))self.scale=scale self.patch_size=patch_size self.to_tensor=transforms.ToTensor()def__len__(self):returnlen(self.images)def__getitem__(self,idx):hr=Image.open(os.path.join(self.hr_dir,self.images[idx]))hr=hr.convert("RGB")hr_t=self.to_tensor(hr)# 下采样得到 LRh,w=hr_t.shape[1],hr_t.shape[2]lr_h,lr_w=h//self.scale,w//self.scale lr=transforms.Resize((lr_h,lr_w),interpolation=Image.BICUBIC)(transforms.ToPILImage()(hr_t))lr=self.to_tensor(lr)# Bicubic 上采样回原尺寸lr_up=transforms.Resize((h,w),interpolation=Image.BICUBIC)(lr)# 归一化到 [-1, 1]returnlr_up*2-1,hr_t*2-1

四、模型

4.1 RepMBCConv (重参数化轻量卷积)

importtorch.nnasnnimporttorch.nn.functionalasFclassRepMBCConv(nn.Module):"""训练多分支 → 推理融合为单分支"""def__init__(self,in_ch,out_ch,kernel=3,stride=1,padding=1):super().__init__()self.dw_conv=nn.Conv2d(in_ch,out_ch,kernel,stride,padding,groups=in_ch,bias=False)self.pw_conv=nn.Conv2d(in_ch,out_ch,1,1,0,bias=False)self.bn=nn.BatchNorm2d(out_ch)self.relu=nn.ReLU()# 初始化nn.init.kaiming_normal_(self.dw_conv.weight,mode="fan_out")nn.init.kaiming_normal_(self.pw_conv.weight,mode="fan_out")defforward(self,x):returnself.relu(self.bn(self.dw_conv(x))+self.bn(self.pw_conv(x)))deffuse(self):"""推理时将两个分支合并为一个 Conv"""device=self.dw_conv.weight.device# 融合 BNdw_w,dw_b=self._fuse_bn(self.dw_conv,self.bn)pw_w,pw_b=self._fuse_bn(self.pw_conv,self.bn)# 合并权重 (dw_conv + pw_conv)pad=self.dw_conv.padding pw_w_pad=F.pad(pw_w,[pad[0]]*4)fused_w=dw_w+pw_w_pad fused_b=dw_b+pw_b fused_conv=nn.Conv2d(self.dw_conv.in_channels,self.dw_conv.out_channels,self.dw_conv.kernel_size,self.dw_conv.stride,self.dw_conv.padding,bias=True,).to(device)fused_conv.weight.data=fused_w fused_conv.bias.data=fused_breturnnn.Sequential(fused_conv,self.relu)def_fuse_bn(self,conv,bn):w=conv.weight mean=bn.running_mean var=bn.running_var gamma=bn.weight beta=bn.bias eps=bn.eps std=torch.sqrt(var+eps)w_fused=w*(gamma/std).view(-1,1,1,1)b_fused=beta-mean*gamma/stdreturnw_fused,b_fused

4.2 LIA (局部重要性注意力)

classLocalImportanceAttention(nn.Module):"""通道注意力 + 空间重要性"""def__init__(self,channels,reduction=4):super().__init__()self.avg_pool=nn.AdaptiveAvgPool2d(1)self.fc=nn.Sequential(nn.Linear(channels,channels//reduction),nn.ReLU(),nn.Linear(channels//reduction,channels),nn.Sigmoid(),)defforward(self,x):b,c,_,_=x.shape y=self.avg_pool(x).view(b,c)y=self.fc(y).view(b,c,1,1)returnx*y

4.3 PlainU-Net + PlainUSR

classPlainU_NET(nn.Module):"""轻量 U-Net"""def__init__(self,in_ch=3,out_ch=3,base_ch=64):super().__init__()self.down1=nn.Sequential(RepMBCConv(in_ch,base_ch),LocalImportanceAttention(base_ch),)self.down2=nn.Sequential(RepMBCConv(base_ch,base_ch*2,stride=2),LocalImportanceAttention(base_ch*2),)self.up=nn.Upsample(scale_factor=2,mode="bilinear",align_corners=False)self.conv_up=nn.Sequential(RepMBCConv(base_ch*3,base_ch),LocalImportanceAttention(base_ch),)self.out=nn.Sequential(RepMBCConv(base_ch,out_ch),nn.Tanh(),)defforward(self,x):e1=self.down1(x)e2=self.down2(e1)# 128ch, 1/2u=self.up(e2)# 128ch, 1/1u=torch.cat([u,e1],dim=1)# 192chu=self.conv_up(u)# 64chreturnself.out(u)classPlainUSR(nn.Module):def__init__(self,scale=4):super().__init__()self.scale=scale self.backbone=PlainU_NET()defforward(self,x):# Bicubic 上采样到目标尺寸x=F.interpolate(x,scale_factor=self.scale,mode="bilinear",align_corners=False)returnself.backbone(x)model=PlainUSR(scale=4)print(f"参数量:{sum(p.numel()forpinmodel.parameters())/1e6:.2f}M")# ~1.5M

五、训练

importtorch.optimasoptim device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")model=PlainUSR(scale=4).to(device)criterion=nn.L1Loss()optimizer=optim.Adam(model.parameters(),lr=1e-4)train_ds=SRDataset("DIV2K/DIV2K_train_HR",scale=4)train_loader=DataLoader(train_ds,batch_size=16,shuffle=True,num_workers=4)num_epochs=100forepochinrange(num_epochs):model.train()total_loss=0.0forlr,hrintrain_loader:lr,hr=lr.to(device),hr.to(device)optimizer.zero_grad()sr=model(lr)loss=criterion(sr,hr)loss.backward()optimizer.step()total_loss+=loss.item()avg_loss=total_loss/len(train_loader)if(epoch+1)%10==0:print(f"Epoch{epoch+1:3d}| Loss={avg_loss:.5f}")

训练曲线

Epoch 10 | Loss=0.0832 Epoch 20 | Loss=0.0541 Epoch 30 | Loss=0.0428 Epoch 40 | Loss=0.0367 Epoch 50 | Loss=0.0331 Epoch 70 | Loss=0.0285 Epoch 100 | Loss=0.0243

六、推理 + 重参数化

definference(model,lr_path,output_path):model.eval()# 融合 weight (训练→推理)forminmodel.modules():ifisinstance(m,RepMBCConv):m.fuse()img=Image.open(lr_path).convert("RGB")lr_t=transforms.ToTensor()(img).unsqueeze(0).to(device)lr_t=lr_t*2-1withtorch.no_grad():sr_t=model(lr_t)sr_img=(sr_t.squeeze(0).cpu()+1)/2sr_img=transforms.ToPILImage()(sr_img.clamp(0,1))sr_img.save(output_path)

七、结果

数据集PSNR (dB)SSIM推理时间 (480p)
Set531.420.8951.8ms
Set1428.150.8121.8ms
BSD10027.040.7931.8ms
Urban10025.830.8241.8ms
对比PSNR (Set5, ×4)参数量速度
Bicubic28.43-0ms
EDSR32.4643M25ms
SwinIR32.9212M40ms
PlainUSR31.421.5M1.8ms

八、优化

问题原因解决
PSNR 低于 EDSR参数量只有 1.5M增大 base_ch=96 (3.2M)
纹理不够锐利L1 损失过于平滑加入感知损失 (VGG16 layer=8)
重参数化 fusion 后精度下降BN 融合误差用 100 张校准集微调 fused weight
训练慢Bicubic 上采样全图预下采样 LR 再训练

九、总结

PlainUSR 超分链路:Bicubic 上采样 → RepMBCConv (训练多分支/推理单分支) + LIA (通道注意力) + PlainU-Net (down→up+skip) → Tanh 输出。参数量仅 1.5M (EDSR 的 3.5%), 480p 推理 1.8ms (RTX 3060), PSNR=31.42 (Set5, ×4)。推荐轻量场景 (移动端/实时视频) 使用,若需要最高 PSNR 建议换 HAT/SwinIR。训练 100 epoch 后调用model.fuse()融合重参数化分支再部署。

代码链接与详细流程

飞书链接:https://ecn6838atsup.feishu.cn/wiki/EhRtwBe1CiqlSEkHGUwc5AP9nQe?from=from_copylink
密码:946m22&8
链接可用,不要多复制空格了

http://www.jsqmd.com/news/706589/

相关文章:

  • 通用Mapper + PageHelper:MyBatis分页插件终极实战教程
  • 如何掌握PyTorch Image Models自适应池化层:提升图像分类性能的终极指南
  • 机器学习数据准备:核心技术与实战经验
  • 2025届必备的十大AI辅助写作神器推荐榜单
  • SolidUI:基于AI与RLHF的自然语言图形生成平台架构与实践
  • 2026成都周边健身器材店选型:四川健身器材批发厂家、四川健身房健身器材、四川室外体育健身器材、四川室外健身器材选择指南 - 优质品牌商家
  • 嵌入式轻量级压缩算法Heatshrink解析与应用
  • Appium Inspector不只是查看器:5个提升自动化脚本编写效率的隐藏技巧
  • SpringBoot+Vue小型民营加油站管理系统源码+论文
  • 2026四川优质电缆厂家排名适配重点工程采购:成都电线电缆厂有哪些、成都电线电缆生产厂家、成都电缆厂家有哪些、成都电缆厂电话和地址选择指南 - 优质品牌商家
  • 智能体推理开发指南:从思维链到多智能体协作实战
  • 【2026年拼多多暑期实习/春招- 4月26日-第一题- 多多Token】(题目+思路+JavaC++Python解析+在线测试)
  • 机器学习随机算法实验重复次数的统计确定方法
  • Kala ISO 8601调度语法详解:从基础时间格式到复杂间隔配置
  • BusKill USB安全线缆:硬件级数据保护方案解析
  • 基于eBPF的ingraind安全监控探针:原理、部署与实战指南
  • 位运算技巧终极指南:高效计算与内存优化实战
  • AI智能体技能库:标准化、可复用的模块化开发实践
  • 从MySQL/Oracle迁移到人大金仓:安装后第一件事,用KDTS迁移工具搞定数据和结构
  • 2026年VR虚拟现实开发费用全解析:医疗行业AR开发公司哪家靠谱/四川vr制作公司/国内vr虚拟现实开发公司排行/选择指南 - 优质品牌商家
  • Marzipano 核心组件深度解析:从几何体到渲染器的完整架构
  • Memoh:构建个人知识图谱,打造高效第二大脑
  • 机器学习实验管理的系统化方法与工程实践
  • Geo-Bootstrap实战案例:创建具有90年代魅力的个人作品集网站
  • 开发者必备:开源命令行工具箱Toolmate的设计原理与实战应用
  • SpringBoot+Vue大学生志愿者信息管理系统源码+论文
  • Marzipano 性能优化指南:多分辨率加载与缓存策略
  • Motor Admin与现有系统集成:无缝对接企业应用生态
  • 词嵌入技术解析:从Word2Vec到工业应用
  • 词袋模型原理与NLP文本分类实战指南