当前位置: 首页 > news >正文

YOLOv8改进 - 注意力机制 | MSCA (Multi-Scale Convolutional Attention) 即插即用增强复杂场景小目标检测鲁棒性

前言

本文介绍了多尺度卷积注意力(MSCA)及其在YOLOv8中的结合应用。基于变换器的模型在语义分割领域占主导,但卷积注意力在编码上下文信息方面更高效。MSCA由深度卷积聚合局部信息、多分支深度卷积捕获多尺度上下文信息、1×1逐点卷积模拟通道关系三部分组成。我们将MSCA代码引入指定目录,在ultralytics/nn/tasks.py中注册,配置yolov8_MSCA.yaml文件,最后通过实验脚本和结果验证了改进的有效性。

文章目录: YOLOv8改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLOv8改进专栏

文章目录

  • 前言
  • 介绍
    • 摘要
  • 文章链接
  • 基本原理
  • 参考代码
  • 引入代码
  • tasks.py 注册
    • 步骤1:
    • 步骤2
  • 配置yolov8_MSCA.yaml
  • 实验
    • 脚本
    • 结果

介绍

摘要

我们提出了SegNeXt,一种用于语义分割的简单卷积网络架构。最近基于变换器的模型由于自注意力在编码空间信息方面的效率而在语义分割领域占据主导地位。在本文中,我们展示了卷积注意力是一种比变换器中的自注意力机制更高效和有效的编码上下文信息的方式。通过重新审视成功的分割模型所拥有的特征,我们发现了几个关键组件,这些组件导致了分割模型性能的提升。这激励我们设计了一种新颖的卷积注意力网络,该网络使用廉价的卷积操作。没有任何花哨的技巧,我们的SegNeXt在包括ADE20K、Cityscapes、COCO-Stuff、Pascal VOC、Pascal Context和iSAID在内的流行基准测试上,显著提高了先前最先进方法的性能。值得注意的是,SegNeXt超越了EfficientNet-L2 w/ NAS-FPN,在Pascal VOC 2012测试排行榜上仅使用1/10的参数就达到了90.6%的mIoU。平均而言,与最先进的方法相比,SegNeXt在ADE20K数据集上的mIoU提高了约2.0%,同时计算量相同或更少。

文章链接

论文地址:论文地址

中文论文:论文地址

代码地址:代码地址

参考代码地址:参考代码地址

基本原理

MSCA 主要由三个部分组成:(1)一个深度卷积用于聚 合局部信息;(2)多分支深度卷积用于捕获多尺度上下文信息;(3)一个 1 × 1 逐点卷积用于模拟特征中不同通道之间的关系。1 × 1 逐点卷积的输出被直接用 作卷积注意力的权重,以重新权衡 MSCA 的输入。

MSCA 可以写成 如下形式:其中 F 代表输入特征,Att 和 Out 分别为注意力权重和输出,⊗ 表示逐元素的矩 阵乘法运算,DW­Conv 表示深度卷积,Scalei (i ∈ {0, 1, 2, 3}) 表示上图右边侧图中的第 i 个分支,Scale0 为残差连接。遵循[130],在 MSCA 的每个分支中,SegNeXt 使用两个深度条带卷积来近似模拟大卷积核的深度卷积。每个分支的卷积核大 小分别被设定为 7、11 和 21。 选择深度条带卷积主要考虑到以下两方面原 因:一方面,相较于普通卷积,条带卷积更加轻量化。为了模拟核大小为 7 × 7 的标准二维卷积,只需使用一对 7 × 1 和 1 × 7 的条带卷积。另一方面,在实际 的分割场景中存在一些条状物体,例如人和电线杆。因此,条状卷积可以作为 标准网格状的卷积的补充,有助于提取条状特征。

参考代码

下面代码来源于

https://github.com/open-mmlab/mmsegmentation/blob/c685fe6767c4cadf6b051983ca6208f1b9d1ccb8/mmseg/models/backbones/mscan.py#L115

classMSCAAttention(BaseModule):"""Attention Module in Multi-Scale Convolutional Attention Module (MSCA). Args: channels (int): The dimension of channels. kernel_sizes (list): The size of attention kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. paddings (list): The number of corresponding padding value in attention module. Defaults: [2, [0, 3], [0, 5], [0, 10]]. """def__init__(self,channels,kernel_sizes=[5,[1,7],[1,11],[1,21]],paddings=[2,[0,3],[0,5],[0,10]]):super().__init__()self.conv0=nn.Conv2d(channels,channels,kernel_size=kernel_sizes[0],padding=paddings[0],groups=channels)fori,(kernel_size,padding)inenumerate(zip(kernel_sizes[1:],paddings[1:])):kernel_size_=[kernel_size,kernel_size[::-1]]padding_=[padding,padding[::-1]]conv_name=[f'conv{i}_1',f'conv{i}_2']fori_kernel,i_pad,i_convinzip(kernel_size_,padding_,conv_name):self.add_module(i_conv,nn.Conv2d(channels,channels,tuple(i_kernel),padding=i_pad,groups=channels))self.conv3=nn.Conv2d(channels,channels,1)defforward(self,x):"""Forward function."""u=x.clone()attn=self.conv0(x)# Multi-Scale Feature extractionattn_0=self.conv0_1(attn)attn_0=self.conv0_2(attn_0)attn_1=self.conv1_1(attn)attn_1=self.conv1_2(attn_1)attn_2=self.conv2_1(attn)attn_2=self.conv2_2(attn_2)attn=attn+attn_0+attn_1+attn_2# Channel Mixingattn=self.conv3(attn)# Convolutional Attentionx=attn*ureturnx

下面代码来源于

https://zhuanlan.zhihu.com/p/566607168

classAttentionModule(BaseModule):def__init__(self,dim):super().__init__()self.conv0=nn.Conv2d(dim,dim,5,padding=2,groups=dim)self.conv0_1=nn.Conv2d(dim,dim,(1,7),padding=(0,3),groups=dim)self.conv0_2=nn.Conv2d(dim,dim,(7,1),padding=(3,0),groups=dim)self.conv1_1=nn.Conv2d(dim,dim,(1,11),padding=(0,5),groups=dim)self.conv1_2=nn.Conv2d(dim,dim,(11,1),padding=(5,0),groups=dim)self.conv2_1=nn.Conv2d(dim,dim,(1,21),padding=(0,10),groups=dim)self.conv2_2=nn.Conv2d(dim,dim,(21,1),padding=(10,0),groups=dim)self.conv3=nn.Conv2d(dim,dim,1)defforward(self,x):u=x.clone()attn=self.conv0(x)attn_0=self.conv0_1(attn)attn_0=self.conv0_2(attn_0)attn_1=self.conv1_1(attn)attn_1=self.conv1_2(attn_1)attn_2=self.conv2_1(attn)attn_2=self.conv2_2(attn_2)attn=attn+attn_0+attn_1+attn_2 attn=self.conv3(attn)returnattn*u

引入代码

在根目录下的ultralytics/nn/目录,新建一个attention目录,然后新建一个以MSCA为文件名的py文件, 把代码拷贝进去。

importtorchimporttorch.nnasnnfromtorch.nnimportfunctionalasFclassMSCAAttention(nn.Module):def__init__(self,dim):super().__init__()self.conv0=nn.Conv2d(dim,dim,5,padding=2,groups=dim)self.conv0_1=nn.Conv2d(dim,dim,(1,7),padding=(0,3),groups=dim)self.conv0_2=nn.Conv2d(dim,dim,(7,1),padding=(3,0),groups=dim)self.conv1_1=nn.Conv2d(dim,dim,(1,11),padding=(0,5),groups=dim)self.conv1_2=nn.Conv2d(dim,dim,(11,1),padding=(5,0),groups=dim)self.conv2_1=nn.Conv2d(dim,dim,(1,21),padding=(0,10),groups=dim)self.conv2_2=nn.Conv2d(dim,dim,(21,1),padding=(10,0),groups=dim)self.conv3=nn.Conv2d(dim,dim,1)defforward(self,x):u=x.clone()attn=self.conv0(x)attn_0=self.conv0_1(attn)attn_0=self.conv0_2(attn_0)attn_1=self.conv1_1(attn)attn_1=self.conv1_2(attn_1)attn_2=self.conv2_1(attn)attn_2=self.conv2_2(attn_2)attn=attn+attn_0+attn_1+attn_2 attn=self.conv3(attn)returnattn*u

tasks.py 注册

ultralytics/nn/tasks.py中进行如下操作:

步骤1:

fromultralytics.nn.attention.MSCAimportMSCAAttention

步骤2

修改def parse_model(d, ch, verbose=True):

只需要添加截图中标明的,其他没有的模块不用添加。

elif m in {MSCAAttention}: c2 = ch[f] args = [c2, *args]

配置yolov8_MSCA.yaml

ultralytics/cfg/models/v8/yolov8_MSCA.yaml

# Ultralytics YOLO 🚀, GPL-3.0 license# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parametersnc:2# number of classesscales:# model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n:[0.33,0.25,1024]# YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPss:[0.33,0.50,1024]# YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPsm:[0.67,0.75,768]# YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPsl:[1.00,1.00,512]# YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx:[1.00,1.25,512]# YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbonebackbone:# [from, repeats, module, args]-[-1,1,Conv,[64,3,2]]# 0-P1/2-[-1,1,Conv,[128,3,2]]# 1-P2/4-[-1,3,C2f,[128,True]]-[-1,1,Conv,[256,3,2]]# 3-P3/8-[-1,6,C2f,[256,True]]-[-1,1,Conv,[512,3,2]]# 5-P4/16-[-1,6,C2f,[512,True]]-[-1,1,Conv,[1024,3,2]]# 7-P5/32-[-1,3,C2f,[1024,True]]-[-1,1,SPPF,[1024,5]]# 9-[-1,1,MSCAAttention,[]]# 10# YOLOv8.0n headhead:-[-1,1,nn.Upsample,[None,2,'nearest']]-[[-1,6],1,Concat,[1]]# cat backbone P4-[-1,3,C2f,[512]]# 13-[-1,1,nn.Upsample,[None,2,'nearest']]-[[-1,4],1,Concat,[1]]# cat backbone P3-[-1,3,C2f,[256]]# 16 (P3/8-small)-[-1,1,Conv,[256,3,2]]-[[-1,13],1,Concat,[1]]# cat head P4-[-1,3,C2f,[512]]# 19 (P4/16-medium)-[-1,1,Conv,[512,3,2]]-[[-1,10],1,Concat,[1]]# cat head P5-[-1,3,C2f,[1024]]# 22 (P5/32-large)-[[16,19,22],1,Detect,[nc]]# Detect(P3, P4, P5)

实验

脚本

importosfromultralyticsimportYOLO yaml='ultralytics/cfg/models/v8/yolov8_MSCA.yaml'model=YOLO(yaml)model.info()if__name__=="__main__":results=model.train(data='ultralytics/datasets/original-license-plates.yaml',name='yolov8_MSCA',epochs=10,workers=8,batch=1)

结果

http://www.jsqmd.com/news/250094/

相关文章:

  • YOLOv8改进 - 注意力机制 | CAA (Context Anchor Attention) 上下文锚点注意力增强复杂场景多尺度目标特征感知
  • AI驱动供应商管理,AI应用架构师引领供应链数字化变革
  • 【课程设计/毕业设计】基于深度学习python-pytorch训练识别舌头是否健康
  • YOLOv8改进 - 注意力机制 | HaloNet 局部自注意力网络通过分块与扩展感受野实现高效空间交互建模
  • YOLOv8改进 - 注意力机制 | MCA (Multidimensional Collaborative Attention) 多维协作注意力通过三分支结构增强通道与空间特征协同建模
  • 深度学习毕设选题推荐:基于python-pytorch卷神经网络训练识别舌头是否健康
  • 深度学习毕设选题推荐:基于python机器学习-pytorch-CNN训练识别服装服饰
  • 7个数据安全策略保证YashanDB的安全执行
  • Rust unsafe 一文全功能解析
  • 7个为什么选择YashanDB的理由,助力企业决策
  • 代码混淆的AI优化:安全性与性能平衡
  • 深度学习毕设项目推荐-基于python-pytorch训练识别舌头是否健康
  • Java毕设项目推荐-基于Web的校运动会管理系统设计与实现基于SpringBoot的民运会赛务管理系统的设计与实现【附源码+文档,调试定制服务】
  • 计算机深度学习毕设实战-基于机器学习 python-pytorch训练识别舌头是否健康
  • Java毕设项目推荐-基于java的车辆违章信息管理系统的设计与实现基于JavaEE的车辆违章信息管理系统的设计与实现【附源码+文档,调试定制服务】
  • 7种常见的YashanDB数据库故障及处理办法
  • 手把手教你:提示工程架构师完成提示工程系统持续部署
  • 深度学习毕设项目推荐-基于python深度学习的道路车辆内有无佩戴安全带识别
  • 深度学习毕设项目:基于python-pytorch机器学习 训练识别舌头是否健康
  • 深度学习毕设项目推荐-基于python-pytorch-CNN训练识别服装服饰
  • 8个步骤快速部署YashanDB数据库环境
  • 深度学习计算机毕设之基于python-pytorch训练识别舌头是否健康卷神经网络
  • 【毕业设计】基于python-pytorch深度学习训练识别舌头是否健康
  • django Python在线学习网站的设计与实现
  • 强烈安利9个AI论文网站,研究生高效写作必备!
  • 【毕业设计】基于python深度学习的道路车辆内有无佩戴安全带识别
  • 亲测好用9个AI论文工具,继续教育学生轻松写论文!
  • django公务员应届生复习备考平台
  • 【课程设计/毕业设计】基于深度学习python-pytorch-CNN训练识别服装服饰
  • 什么是SPN网络