当前位置: 首页 > news >正文

YOLOv8改进 - 注意力机制 | CoTAttention (Contextual Transformer Attention) 上下文转换器注意力通过静态与动态上下文协同建模增强视觉表征

前言

本文介绍了上下文Transformer(CoT)块及其在YOLOv8中的结合应用。大多数现有设计未充分利用邻近键的上下文信息,为此提出CoT块。它先通过3×3卷积对输入键进行上下文编码得到静态表示,再与输入查询连接,经两个1×1卷积学习动态多头注意力矩阵,与输入值相乘得到动态表示,最后融合二者作为输出。我们将CoTAttention代码引入指定目录,在ultralytics/nn/tasks.py中注册,配置yolov8-CoTAttention.yaml文件,最后经实验脚本和结果验证了改进的有效性。

文章目录: YOLOv8改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLOv8改进专栏

文章目录

  • 前言
  • 介绍
    • 摘要
  • 文章链接
  • 基本原理
  • 核心代码
  • 引入代码
  • tasks注册
    • 步骤1:
    • 步骤2
  • 配置yolov8-CoTAttention.yaml
  • 实验
    • 脚本
    • 结果

介绍

摘要

Transformer自注意力机制已经引领了自然语言处理领域的革命,并且最近激发了Transformer风格架构设计在众多计算机视觉任务中取得竞争性结果。然而,大多数现有设计直接在二维特征图上使用自注意力机制,以基于每个空间位置的孤立查询和键对来获取注意力矩阵,但没有充分利用邻近键之间的丰富上下文信息。在这项工作中,我们设计了一种新颖的Transformer风格模块,即Contextual Transformer(CoT)块,用于视觉识别。该设计充分利用了输入键之间的上下文信息,以引导动态注意力矩阵的学习,从而增强视觉表示的能力。

在技术上,CoT块首先通过一个3×3卷积对输入键进行上下文编码,导致输入的静态上下文表示。我们进一步将编码后的键与输入查询连接起来,通过两个连续的1×1卷积来学习动态多头注意力矩阵。学习到的注意力矩阵与输入值相乘,以实现输入的动态上下文表示。静态和动态上下文表示的融合最终作为输出。我们提出的CoT块非常有吸引力,因为它可以轻松替换ResNet架构中的每一个3×3卷积,从而生成一种名为Contextual Transformer Networks(CoTNet)的Transformer风格骨干网络。通过在广泛应用(例如图像识别、目标检测和实例分割)中的大量实验,我们验证了CoTNet作为更强骨干网络的优越性。源码可在https://github.com/JDAI-CV/CoTNet获取。

文章链接

论文地址:论文地址

代码地址:代码地址

基本原理

CoTNet是一种基于Contextual Transformer(CoT)模块的网络结构,其原理如下:

  1. CoTNet原理:

    • CoTNet采用Contextual Transformer(CoT)模块作为构建块,用于替代传统的卷积操作。
    • CoT模块利用3×3卷积来对输入键之间的上下文信息进行编码,生成静态上下文表示。
    • 将编码后的键与输入查询连接,通过两个连续的1×1卷积来学习动态多头注意力矩阵。
    • 学习到的注意力矩阵用于聚合所有输入数值,生成动态上下文表示。
    • 最终将静态和动态上下文表示融合作为输出。

  2. Contextual Transformer Attention在CoTNet中的作用和原理:

    Contextual Transformer Attention是Contextual Transformer(CoT)模块中的关键组成部分,用于引导动态学习注意力矩阵,从而增强视觉表示并提高计算机视觉任务的性能

    • Contextual Transformer Attention是CoT模块中的注意力机制,用于引导动态学习注意力矩阵。
    • 通过Contextual Transformer Attention,模型能够充分利用输入键之间的上下文信息,从而更好地捕捉动态关系。
    • 这种注意力机制有助于增强视觉表示,并提高计算机视觉任务的性能。
    • CoTNet通过整合Contextual Transformer Attention,实现了同时进行上下文挖掘和自注意力学习的优势,从而提升了深度网络的表征能力。

核心代码

importtorchfromtorchimportflatten,nnfromtorch.nnimportfunctionalasFclassCoTAttention(nn.Module):def__init__(self,dim=512,kernel_size=3):super().__init__()self.dim=dim# 输入通道数self.kernel_size=kernel_size# 卷积核大小# 关键信息嵌入层,使用分组卷积提取特征self.key_embed=nn.Sequential(nn.Conv2d(dim,dim,kernel_size=kernel_size,padding=kernel_size//2,groups=4,bias=False),nn.BatchNorm2d(dim),# 归一化层nn.ReLU()# 激活函数)# 值信息嵌入层,使用1x1卷积进行特征转换self.value_embed=nn.Sequential(nn.Conv2d(dim,dim,1,bias=False),nn.BatchNorm2d(dim)# 归一化层)# 注意力机制嵌入层,先降维后升维,最终输出与卷积核大小和通道数相匹配的特征factor=4# 降维比例self.attention_embed=nn.Sequential(nn.Conv2d(2*dim,2*dim//factor,1,bias=False),nn.BatchNorm2d(2*dim//factor),# 归一化层nn.ReLU(),# 激活函数nn.Conv2d(2*dim//factor,kernel_size*kernel_size*dim,1)# 升维匹配卷积核形状)defforward(self,x):bs,c,h,w=x.shape# 输入特征的尺寸k1=self.key_embed(x)# 应用关键信息嵌入v=self.value_embed(x).view(bs,c,-1)# 应用值信息嵌入,并展平y=torch.cat([k1,x],dim=1)# 将关键信息和原始输入在通道维度上拼接att=self.attention_embed(y)# 应用注意力机制嵌入层att=att.reshape(bs,c,self.kernel_size*self.kernel_size,h,w)att=att.mean(2,keepdim=False).view(bs,c,-1)# 计算平均后展平k2=F.softmax(att,dim=-1)*v# 应用softmax进行标准化,并与值信息相乘k2=k2.view(bs,c,h,w)# 重塑形状与输入相同returnk1+k2# 将两部分信息相加并返回

引入代码

在根目录下的ultralytics/nn/目录,新建一个attention目录,然后新建一个以CoTAttention为文件名的py文件, 把代码拷贝进去。

importtorchfromtorchimportflatten,nnfromtorch.nnimportfunctionalasFclassCoTAttention(nn.Module):def__init__(self,dim=512,kernel_size=3):super().__init__()self.dim=dim self.kernel_size=kernel_size self.key_embed=nn.Sequential(nn.Conv2d(dim,dim,kernel_size=kernel_size,padding=kernel_size//2,groups=4,bias=False),nn.BatchNorm2d(dim),nn.ReLU())self.value_embed=nn.Sequential(nn.Conv2d(dim,dim,1,bias=False),nn.BatchNorm2d(dim))factor=4self.attention_embed=nn.Sequential(nn.Conv2d(2*dim,2*dim//factor,1,bias=False),nn.BatchNorm2d(2*dim//factor),nn.ReLU(),nn.Conv2d(2*dim//factor,kernel_size*kernel_size*dim,1))defforward(self,x):bs,c,h,w=x.shape k1=self.key_embed(x)# bs,c,h,wv=self.value_embed(x).view(bs,c,-1)# bs,c,h,wy=torch.cat([k1,x],dim=1)# bs,2c,h,watt=self.attention_embed(y)# bs,c*k*k,h,watt=att.reshape(bs,c,self.kernel_size*self.kernel_size,h,w)att=att.mean(2,keepdim=False).view(bs,c,-1)# bs,c,h*wk2=F.softmax(att,dim=-1)*v k2=k2.view(bs,c,h,w)returnk1+k2

tasks注册

ultralytics/nn/tasks.py中进行如下操作:

步骤1:

fromultralytics.nn.attention.CoTAttentionimportCoTAttention

步骤2

修改def parse_model(d, ch, verbose=True):

elifmisCoTAttention:c1,c2=ch[f],args[0]ifc2!=nc:c2=make_divisible(min(c2,max_channels)*width,8)args=[c1,*args[1:]]

配置yolov8-CoTAttention.yaml

ultralytics/ultralytics/cfg/models/v8/yolov8-CoTAttention.yaml

# Ultralytics YOLO 🚀, GPL-3.0 license# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parametersnc:80# number of classesscales:# model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n:[0.33,0.25,1024]# YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPss:[0.33,0.50,1024]# YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPsm:[0.67,0.75,768]# YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPsl:[1.00,1.00,512]# YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx:[1.00,1.25,512]# YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbonebackbone:# [from, repeats, module, args]-[-1,1,Conv,[64,3,2]]# 0-P1/2-[-1,1,Conv,[128,3,2]]# 1-P2/4-[-1,3,C2f,[128,True]]-[-1,1,Conv,[256,3,2]]# 3-P3/8-[-1,6,C2f,[256,True]]-[-1,1,Conv,[512,3,2]]# 5-P4/16-[-1,6,C2f,[512,True]]-[-1,1,Conv,[1024,3,2]]# 7-P5/32-[-1,3,C2f,[1024,True]]-[-1,1,SPPF,[1024,5]]# 9# YOLOv8.0n headhead:-[-1,1,nn.Upsample,[None,2,'nearest']]-[[-1,6],1,Concat,[1]]# cat backbone P4-[-1,3,C2f,[512]]# 12-[-1,1,nn.Upsample,[None,2,'nearest']]-[[-1,4],1,Concat,[1]]# cat backbone P3-[-1,3,C2f,[256]]# 15 (P3/8-small)-[-1,3,CoTAttention,[256]]# 16-[-1,1,Conv,[256,3,2]]-[[-1,12],1,Concat,[1]]# cat head P4-[-1,3,C2f,[512]]# 19 (P4/16-medium)-[-1,3,CoTAttention,[512]]-[-1,1,Conv,[512,3,2]]-[[-1,9],1,Concat,[1]]# cat head P5-[-1,3,C2f,[1024]]# 23 (P5/32-large)-[-1,3,CoTAttention,[1024]]-[[16,20,24],1,Detect,[nc]]# Detect(P3, P4, P5)

实验

脚本

importosfromultralyticsimportYOLO yaml='ultralytics/cfg/models/v8/yolov8-CoTAttention.yaml'model=YOLO(yaml)model.info()if__name__=="__main__":results=model.train(data='coco128.yaml',name='CoTAttention',epochs=10,workers=8,batch=1)

结果

http://www.jsqmd.com/news/280506/

相关文章:

  • 【大数据毕设源码分享】基于python+Hadoop+数据可视化的租房数据分析系统的设计与实现(程序+文档+代码讲解+一条龙定制)
  • C#/.NET/.NET Core技术前沿周刊 | 第 66 期(2026年1.12-1.18)
  • 实用指南:清楚易懂的红黑树讲解
  • Java计算机毕设之基于springboot的元宇宙平台的房屋租赁管理系统基于springboot + vue房屋租赁管理系统(完整前后端代码+说明文档+LW,调试定制等)
  • 迈向意义共治的智能文明:一份关于AI时代新范式的框架性阐述
  • 学习日记之狂神说Java
  • [note] 本地12+16G极限部署 Qwen3-Coder-25B 搭配Continue插件实现代码补全
  • Java计算机毕设之基于springboot的婚庆公司服务平台的设计与实现婚庆摄影(完整前后端代码+说明文档+LW,调试定制等)
  • Java毕设项目:基于springboot的婚庆公司服务平台的设计与实现(源码+文档,讲解、调试运行,定制等)
  • 【性能测试】14_JMeter _JMeter测试报告
  • 【毕业设计】基于springboot的实验设备借用平台的设计与实现 实验室设备租赁系统(源码+文档+远程调试,全bao定制等)
  • Java毕设选题推荐:基于SpringBoot+Vue+MySQL 房屋租赁管理系统平台基于springboot的元宇宙平台的房屋租赁管理系统【附源码、mysql、文档、调试+代码讲解+全bao等】
  • 2026必备!10个AI论文工具,助本科生轻松写论文!
  • 【课程设计/毕业设计】基于springboot+vue的婚庆公司服务网站管理系统基于springboot的婚庆公司服务平台的设计与实现【附源码、数据库、万字文档】
  • K8s新手入门:从“Pod创建”到“服务暴露”,3个案例理解容器编排
  • 【旋转式多线激光雷达】旋转式多线激光雷达工作原理
  • ClickHouse在农业大数据分析中的创新应用
  • agentscope记忆模块使用和部署agent-memory-server记忆服务
  • 【毕业设计】基于springboot的婚庆公司服务平台的设计与实现(源码+文档+远程调试,全bao定制等)
  • 在决策树生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶结点。 - 指南
  • AI Agent核心技术揭秘:概念辨析、商业化路径与实践指南,值得收藏
  • Java程序员转型大模型开发全攻略:月薪30K+的AI工程师成长路径_程序员转行AI大模型教程(非常详细)
  • docker部署及基本要点
  • 无线网络仿真:无线网络基础_(19).网络协议栈仿真
  • 【大数据毕设全套源码+文档】基于Python+数据可视化的黑龙江旅游景点数据分析系统的设计与实现(丰富项目+远程调试+讲解+定制)
  • 【大数据毕设全套源码+文档】基于springboot+大数据的音乐数据分析系统的设计与实现(丰富项目+远程调试+讲解+定制)
  • 资治通鉴对于大赦天下的评价
  • docker安装部署PostgreSQL带有pgvector扩展向量数据(高维数组)
  • 【大数据毕设全套源码+文档】基于springboot+Hadoop的手机销售数据分析系统的设计与实现(丰富项目+远程调试+讲解+定制)
  • 基于提供的镜像构建PostGIS、pgvector 的 PostgreSQL 18镜像的Dockerfile