当前位置: 首页 > news >正文

YOLO26改进 - 采样 | ICCV 顶会技术:WaveletPool 小波池化强化采样,保留小目标细节

前言

本文介绍了基于小波变换的池化方法——Wavelet Pooling,作为传统最大池化与平均池化的有效替代方案。该方法通过两级小波分解丢弃高频子带,保留更具代表性的低频特征,从而在减少信息丢失的同时提升模型的正则化能力。我们将 Wavelet Pool 和 UnPool 成功集成进 YOLO26,替代原有的下采样与上采样模块,实现更高效的特征提取与恢复。实验证明,YOLO26-WaveletPool 在多个分类与检测任务中均取得优异表现,展现了小波池化在深度学习中的广泛应用前景。

文章目录: YOLO26改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLO26改进专栏

文章目录

  • 前言
  • 介绍
    • 摘要
  • 文章链接
  • 基本原理:
      • **小波变换的基本原理**
      • **论文的方法**
  • 核心代码
  • YOLO26引入代码
  • tasks注册
    • 步骤1:导包:
    • 步骤2
  • 配置yolo26-WaveletPool.yaml
  • 实验
    • 脚本
    • 结果

介绍

摘要

卷积神经网络(Convolutional Neural Networks, CNNs)持续推动着二维和三维图像分类及目标识别技术的发展。然而,为了维持这一快速进展,有必要对神经网络中的基础构件进行持续的评估与改进。当前主流的网络正则化方法大多侧重于卷积操作本身,而对池化层的设计选择关注不足。

为此,我们提出了一种新的池化策略——小波池化(Wavelet Pooling),作为传统邻域池化方法(如最大池化和平均池化)的有效替代方案。该方法通过将特征分解为多层小波子带,并舍弃第一层级的高频子带来实现下采样,从而有效降低特征维度。与最大池化中常见的过拟合问题不同,小波池化在降维过程中保留了更多结构信息,具备更强的泛化能力。此外,相比于基于固定邻域的池化方式,小波池化在结构上实现了更紧凑、高效的特征压缩。

我们在四个标准图像分类数据集上进行了系统实验,结果表明:所提出的小波池化方法在性能上显著优于或与最大池化、平均池化、混合池化以及随机池化等主流方法相当,验证了其作为通用池化策略的潜力。

文章链接

论文地址:论文地址

代码地址:代码地址

论文地址:论文地址

基本原理:

首先,池化是一种通过舍弃信息实现正则化效果的操作。然而,传统的池化方法存在一些不足:

  • Max pooling:当重要特征的幅度值低于不重要特征时,重要特征会被忽略。
  • Average pooling:同时接纳幅值大和幅值小的特征,容易稀释关键特征。

为了解决这些问题,该论文提出基于小波变换的池化操作,具体思路如下:


小波变换的基本原理

小波变换可将输入特征图划分为低频子带(LL)和高频子带(LH、HL、HH)。其数学公式为:

  • 一级小波变换:
    L L 1 , L H 1 , H L 1 , H H 1 = D W T ( I ) LL1, LH1, HL1, HH1 = DWT(I)LL1,LH1,HL1,HH1=DWT(I)
    逆变换:
    I = I D W T ( L L 1 , L H 1 , H L 1 , H H 1 ) I = IDWT(LL1, LH1, HL1, HH1)I=IDWT(LL1,LH1,HL1,HH1)

  • 二级小波变换:
    L L 2 , L H 2 , H L 2 , H H 2 = D W T ( L L 1 ) LL2, LH2, HL2, HH2 = DWT(LL1)LL2,LH2,HL2,HH2=DWT(LL1)
    逆变换:
    L L 1 = I D W T ( L L 2 , L H 2 , H L 2 , H H 2 ) LL1 = IDWT(LL2, LH2, HL2, HH2)LL1=IDWT(LL2,LH2,HL2,HH2)

小波变换通过下采样将特征图尺寸缩小一半,逆变换可完美重建原始图像。


论文的方法

该论文方法流程如下:

  1. 对输入图像I II进行两次小波变换,得到:
    L L 2 , ( L H 2 , H L 2 , H H 2 ) , ( L H 1 , H L 1 , H H 1 ) = D W T ( D W T ( I ) ) LL2, (LH2, HL2, HH2), (LH1, HL1, HH1) = DWT(DWT(I))LL2,(LH2,HL2,HH2),(LH1,HL1,HH1)=DWT(DWT(I))
  2. 舍弃最高频子带( L H 1 , H L 1 , H H 1 ) (LH1, HL1, HH1)(LH1,HL1,HH1),保留低频子带( L L 2 , L H 2 , H L 2 , H H 2 ) (LL2, LH2, HL2, HH2)(LL2,LH2,HL2,HH2)
  3. 对保留的二级小波系数进行逆变换,重建池化后的图像:
    I ′ = I D W T ( L L 2 , L H 2 , H L 2 , H H 2 ) I' = IDWT(LL2, LH2, HL2, HH2)I=IDWT(LL2,LH2,HL2,HH2)

核心代码

classWaveletPool(nn.Module):def__init__(self):super(WaveletPool,self).__init__()ll=np.array([[0.5,0.5],[0.5,0.5]])lh=np.array([[-0.5,-0.5],[0.5,0.5]])hl=np.array([[-0.5,0.5],[-0.5,0.5]])hh=np.array([[0.5,-0.5],[-0.5,0.5]])filts=np.stack([ll[None,::-1,::-1],lh[None,::-1,::-1],hl[None,::-1,::-1],hh[None,::-1,::-1]],axis=0)self.weight=nn.Parameter(torch.tensor(filts).to(torch.get_default_dtype()),requires_grad=False)defforward(self,x):C=x.shape[1]filters=torch.cat([self.weight,]*C,dim=0)y=F.conv2d(x,filters,groups=C,stride=2)returnyclassWaveletUnPool(nn.Module):def__init__(self):super(WaveletUnPool,self).__init__()ll=np.array([[0.5,0.5],[0.5,0.5]])lh=np.array([[-0.5,-0.5],[0.5,0.5]])hl=np.array([[-0.5,0.5],[-0.5,0.5]])hh=np.array([[0.5,-0.5],[-0.5,0.5]])filts=np.stack([ll[None,::-1,::-1],lh[None,::-1,::-1],hl[None,::-1,::-1],hh[None,::-1,::-1]],axis=0)self.weight=nn.Parameter(torch.tensor(filts).to(torch.get_default_dtype()),requires_grad=False)defforward(self,x):C=torch.floor_divide(x.shape[1],4)filters=torch.cat([self.weight,]*C,dim=0)y=F.conv_transpose2d(x,filters,groups=C,stride=2)returny

YOLO26引入代码

在根目录下的ultralytics/nn/目录,新建一个sample目录,然后新建一个以WaveletPool为文件名的py文件, 把代码拷贝进去。

importtorchfromtorchimportnnasnnimporttorch.nn.functionalasFimportnumpyasnpclassWaveletPool(nn.Module):def__init__(self):super(WaveletPool,self).__init__()ll=np.array([[0.5,0.5],[0.5,0.5]])lh=np.array([[-0.5,-0.5],[0.5,0.5]])hl=np.array([[-0.5,0.5],[-0.5,0.5]])hh=np.array([[0.5,-0.5],[-0.5,0.5]])filts=np.stack([ll[None,::-1,::-1],lh[None,::-1,::-1],hl[None,::-1,::-1],hh[None,::-1,::-1]],axis=0)self.weight=nn.Parameter(torch.tensor(filts).to(torch.get_default_dtype()),requires_grad=False)defforward(self,x):C=x.shape[1]filters=torch.cat([self.weight,]*C,dim=0)y=F.conv2d(x,filters,groups=C,stride=2)returnyclassWaveletUnPool(nn.Module):def__init__(self):super(WaveletUnPool,self).__init__()ll=np.array([[0.5,0.5],[0.5,0.5]])lh=np.array([[-0.5,-0.5],[0.5,0.5]])hl=np.array([[-0.5,0.5],[-0.5,0.5]])hh=np.array([[0.5,-0.5],[-0.5,0.5]])filts=np.stack([ll[None,::-1,::-1],lh[None,::-1,::-1],hl[None,::-1,::-1],hh[None,::-1,::-1]],axis=0)self.weight=nn.Parameter(torch.tensor(filts).to(torch.get_default_dtype()),requires_grad=False)defforward(self,x):C=torch.floor_divide(x.shape[1],4)filters=torch.cat([self.weight,]*C,dim=0)y=F.conv_transpose2d(x,filters,groups=C,stride=2)returny

tasks注册

ultralytics/nn/tasks.py中进行如下操作:

步骤1:导包:

fromultralytics.nn.sample.WaveletPoolimportWaveletPool,WaveletUnPool

步骤2

修改def parse_model(d, ch, verbose=True):
只需要添加截图中标明的,其他没有的模块不用添加

elifmisWaveletPool:c2=ch[f]*4elifmisWaveletUnPool:c2=ch[f]//4

配置yolo26-WaveletPool.yaml

ultralytics/cfg/models/26/yolo26-WaveletPool.yaml

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license# Ultralytics YOLO26 object detection model with P3/8 - P5/32 outputs# Model docs: https://docs.ultralytics.com/models/yolo26# Task docs: https://docs.ultralytics.com/tasks/detect# Parametersnc:80# number of classesend2end:True# whether to use end-to-end modereg_max:1# DFL binsscales:# model compound scaling constants, i.e. 'model=yolo26n.yaml' will call yolo26.yaml with scale 'n'# [depth, width, max_channels]n:[0.50,0.25,1024]# summary: 260 layers, 2,572,280 parameters, 2,572,280 gradients, 6.1 GFLOPss:[0.50,0.50,1024]# summary: 260 layers, 10,009,784 parameters, 10,009,784 gradients, 22.8 GFLOPsm:[0.50,1.00,512]# summary: 280 layers, 21,896,248 parameters, 21,896,248 gradients, 75.4 GFLOPsl:[1.00,1.00,512]# summary: 392 layers, 26,299,704 parameters, 26,299,704 gradients, 93.8 GFLOPsx:[1.00,1.50,512]# summary: 392 layers, 58,993,368 parameters, 58,993,368 gradients, 209.5 GFLOPs# YOLO26n backbonebackbone:# [from, repeats, module, args]-[-1,1,Conv,[64,3,2]]# 0-P1/2-[-1,1,WaveletPool,[]]# 1-P2/4-[-1,2,C3k2,[256,False,0.25]]-[-1,1,WaveletPool,[]]# 3-P3/8-[-1,2,C3k2,[512,False,0.25]]-[-1,1,WaveletPool,[]]# 5-P4/16-[-1,2,C3k2,[512,True]]-[-1,1,WaveletPool,[]]# 5-P4/16-[-1,2,C3k2,[1024,True]]-[-1,1,SPPF,[1024,5,3,True]]# 9-[-1,2,C2PSA,[1024]]# 10# YOLO26n headhead:-[-1,1,WaveletUnPool,[]]-[[-1,6],1,Concat,[1]]# cat backbone P4-[-1,2,C3k2,[512,True]]# 13-[-1,1,WaveletUnPool,[]]-[[-1,4],1,Concat,[1]]# cat backbone P3-[-1,2,C3k2,[256,True]]# 16 (P3/8-small)-[-1,1,Conv,[256,3,2]]-[[-1,13],1,Concat,[1]]# cat head P4-[-1,2,C3k2,[512,True]]# 19 (P4/16-medium)-[-1,1,Conv,[512,3,2]]-[[-1,10],1,Concat,[1]]# cat head P5-[-1,1,C3k2,[1024,True,0.5,True]]# 22 (P5/32-large)-[[16,19,22],1,Detect,[nc]]# Detect(P3, P4, P5)

实验

脚本

importwarnings warnings.filterwarnings('ignore')fromultralyticsimportYOLOif__name__=='__main__':# 修改为自己的配置文件地址model=YOLO('./ultralytics/cfg/models/26/yolo26-WaveletPool.yaml')# 修改为自己的数据集地址model.train(data='./ultralytics/cfg/datasets/coco8.yaml',cache=False,imgsz=640,epochs=10,single_cls=False,# 是否是单类别检测batch=8,close_mosaic=10,workers=0,optimizer='MuSGD',# optimizer='SGD',amp=False,project='runs/train',name='yolo26-WaveletPool',)

结果

http://www.jsqmd.com/news/299646/

相关文章:

  • P1948 [USACO08JAN] Telephone Lines S
  • 深度测评10个AI论文平台,研究生高效写作必备!
  • 图神经网络分享系列-GGNN(GATED GRAPH SEQUENCE NEURAL NETWORKS)(三)
  • 音视频学习(八十六):宏块
  • 完整教程:(数据结构)栈和队列
  • day11|150. 逆波兰表达式求值 239. 滑动窗口最大值 347.前 K 个高频元素
  • 求多个乘法逆元(模板)
  • 语义分割实战——基于EGEUNet神经网络印章分割系统3:含训练测试代码、数据集和GUI交互界面
  • 语义分割实战——基于EGEUNet神经网络印章分割系统2:含训练测试代码和数据集
  • 语义分割实战——基于EGEUNet神经网络印章分割系统1:数据集说明(含下载链接)
  • 强烈安利!本科生毕业论文必备TOP8 AI论文网站测评
  • STM32F0实战:基于HAL库开发【2.3】
  • 工信部教考中心《系统可靠性工程师(高级)》开课通知
  • 机房U位管理别瞎忙!这套系统让运维效率翻倍
  • 告别设备束缚 RetroArch-web 把童年游戏装进口袋,cpolar解锁全场景游玩
  • 使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 44--Pytest框架钩子函数
  • 使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 43--添加allure测试报告显示信息和其他封装方法
  • 云端VS本地 RFID资产管理系统怎么选?优缺点大揭秘
  • Transactional失效的情况总结
  • Spark GIS:分布式计算框架下的空间数据分析
  • 2023年NOC大赛创客智慧编程赛项Python复赛模拟题(一)
  • 2023年NOC大赛创客智慧编程赛项Python复赛模拟题(二)
  • Python大数据项目推荐:基于Hadoop+Spark电商用户行为分析毕设 毕业设计 选题推荐 毕设选题 数据分析 机器学习 数据挖掘
  • Flutter for OpenHarmony 剧本杀组队App实战22:快速匹配功能实现
  • 【计算机毕设选题】基于Spark的双十一美妆数据可视化系统源码 毕业设计 选题推荐 毕设选题 数据分析 机器学习 数据挖掘
  • Reap
  • 信号处理仿真:滤波器设计与仿真_23.滤波器设计与仿真在雷达系统中的应用
  • 信号处理仿真:滤波器设计与仿真_24.滤波器设计与仿真在控制工程中的应用
  • 性价比对比视角|四款热门机型性价比深度拆解
  • 中国智能体应用现状与企业实践