当前位置: 首页 > news >正文

DAY 47 通道注意力(SE注意力)

一、注意力

注意力机制是一种让模型学会「选择性关注重要信息」的特征提取器,就像人类视觉会自动忽略背景,聚焦于图片中的主体(如猫、汽车)。

transformer中的叫做自注意力机制,他是一种自己学习自己的机制,他可以自动学习到图片中的主体,并忽略背景。我们现在说的很多模块,比如通道注意力、空间注意力、通道注意力等等,都是基于自注意力机制的。

从数学角度看,注意力机制是对输入特征进行加权求和,输出=∑(输入特征×注意力权重),其中注意力权重是学习到的。所以他和卷积很像,因为卷积也是一种加权求和。但是卷积是 “固定权重” 的特征提取(如 3x3 卷积核)--训练完了就结束了,注意力是 “动态权重” 的特征提取(权重随输入数据变化)---输入数据不同权重不同。

问:为什么需要多种注意力模块?

答:因为不同场景下的关键信息分布不同。例如,识别鸟类和飞机时,需关注 “羽毛纹理”“金属光泽” 等特定通道的特征,通道注意力可强化关键通道;而物体位置不确定时(如猫出现在图像不同位置),空间注意力能聚焦物体所在区域,忽略背景。复杂场景中,可能需要同时关注通道和空间(如混合注意力模块 CBAM),或处理长距离依赖(如全局注意力模块 Non-local)。

问:为什么不设计一个‘万能’注意力模块?

答:主要受效率和灵活性限制。专用模块针对特定需求优化计算,成本更低(如通道注意力仅需处理通道维度,无需全局位置计算);不同任务的核心需求差异大(如医学图像侧重空间定位,自然语言处理侧重语义长距离依赖),通用模块可能冗余或低效。每个模块新增的权重会增加模型参数量,若训练数据不足或优化不当,可能引发过拟合。因此实际应用中需结合轻量化设计(如减少全连接层参数)、正则化(如 Dropout)或结构约束(如共享注意力权重)来平衡性能与复杂度。

通道注意力(Channel Attention)属于注意力机制(Attention Mechanism)的变体,而非自注意力(Self-Attention)的直接变体。可以理解为注意力是一个动物园算法,里面很多个物种,自注意力只是一个分支,因为开创了transformer所以备受瞩目。

常见注意力模块的归类如下:

注意力模块

所属类别

核心功能

自注意力(Self-Attention)

自注意力变体

建模同一输入内部元素的依赖(如序列位置、图像块)

通道注意力(Channel Attention)

普通注意力变体(全局上下文)

建模特征图通道间的重要性,通过全局池化压缩空间信息

空间注意力(Spatial Attention)

普通注意力变体(全局上下文)

建模特征图空间位置的重要性,关注“哪里”更重要

多头注意力(Multi-Head Attention)

自注意力/普通注意力的增强版

将query/key/value投影到多个子空间,捕捉多维度依赖

编码器-解码器注意力(Encoder-Decoder Attention)

普通注意力变体

建模编码器输出与解码器输入的跨模态交互(如机器翻译中句子与译文的对齐)

二、通道注意力

想要把通道注意力插入到模型中,关键步骤如下:

(1)定义注意力模块

(2)重写之前的模型定义部分,确定好模块插入的位置

1.通道注意力的定义

# ===================== 新增:通道注意力模块(SE模块) ===================== class ChannelAttention(nn.Module): """通道注意力模块(Squeeze-and-Excitation)""" def __init__(self, in_channels, reduction_ratio=16): """ 参数: in_channels: 输入特征图的通道数 reduction_ratio: 降维比例,用于减少参数量 """ super(ChannelAttention, self).__init__() # 全局平均池化 - 将空间维度压缩为1x1,保留通道信息 self.avg_pool = nn.AdaptiveAvgPool2d(1) # 全连接层 + 激活函数,用于学习通道间的依赖关系 self.fc = nn.Sequential( # 降维:压缩通道数,减少计算量 nn.Linear(in_channels, in_channels // reduction_ratio, bias=False), nn.ReLU(inplace=True), # 升维:恢复原始通道数 nn.Linear(in_channels // reduction_ratio, in_channels, bias=False), # Sigmoid将输出值归一化到[0,1],表示通道重要性权重 nn.Sigmoid() ) def forward(self, x): """ 参数: x: 输入特征图,形状为 [batch_size, channels, height, width] 返回: 加权后的特征图,形状不变 """ batch_size, channels, height, width = x.size() # 1. 全局平均池化:[batch_size, channels, height, width] → [batch_size, channels, 1, 1] avg_pool_output = self.avg_pool(x) # 2. 展平为一维向量:[batch_size, channels, 1, 1] → [batch_size, channels] avg_pool_output = avg_pool_output.view(batch_size, channels) # 3. 通过全连接层学习通道权重:[batch_size, channels] → [batch_size, channels] channel_weights = self.fc(avg_pool_output) # 4. 重塑为二维张量:[batch_size, channels] → [batch_size, channels, 1, 1] channel_weights = channel_weights.view(batch_size, channels, 1, 1) # 5. 将权重应用到原始特征图上(逐通道相乘) return x * channel_weights # 输出形状:[batch_size, channels, height, width]

通道注意力模块的核心原理

(1)Squeeze(压缩):

- 通过全局平均池化将每个通道的二维特征图(H×W)压缩为一个标量,保留通道的全局信息。

- 物理意义:计算每个通道在整个图像中的 “平均响应强度”,例如,“边缘检测通道” 在有物体边缘的图像中响应值会更高。

(2)Excitation(激发):

- 通过全连接层+Sigmoid激活,学习通道间的依赖关系,输出0-1之间的权重值。

- 物理意义:让模型自动判断哪些通道更重要(权重接近 1),哪些通道可忽略(权重接近 0)。

(3)Reweight(重加权):

- 将学习到的通道权重与原始特征图逐通道相乘,增强重要通道,抑制不重要通道。

- 物理意义:类似人类视觉系统聚焦于关键特征(如猫的轮廓),忽略无关特征(如背景颜色)

通道注意力插入后,参数量略微提高,增加了特征提取能力

2.模型的重新定义(通道注意力的插入)

class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() # ---------------------- 第一个卷积块 ---------------------- self.conv1 = nn.Conv2d(3, 32, 3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.relu1 = nn.ReLU() # 新增:插入通道注意力模块(SE模块) self.ca1 = ChannelAttention(in_channels=32, reduction_ratio=16) self.pool1 = nn.MaxPool2d(2, 2) # ---------------------- 第二个卷积块 ---------------------- self.conv2 = nn.Conv2d(32, 64, 3, padding=1) self.bn2 = nn.BatchNorm2d(64) self.relu2 = nn.ReLU() # 新增:插入通道注意力模块(SE模块) self.ca2 = ChannelAttention(in_channels=64, reduction_ratio=16) self.pool2 = nn.MaxPool2d(2) # ---------------------- 第三个卷积块 ---------------------- self.conv3 = nn.Conv2d(64, 128, 3, padding=1) self.bn3 = nn.BatchNorm2d(128) self.relu3 = nn.ReLU() # 新增:插入通道注意力模块(SE模块) self.ca3 = ChannelAttention(in_channels=128, reduction_ratio=16) self.pool3 = nn.MaxPool2d(2) # ---------------------- 全连接层(分类器) ---------------------- self.fc1 = nn.Linear(128 * 4 * 4, 512) self.dropout = nn.Dropout(p=0.5) self.fc2 = nn.Linear(512, 10) def forward(self, x): # ---------- 卷积块1处理 ---------- x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.ca1(x) # 应用通道注意力 x = self.pool1(x) # ---------- 卷积块2处理 ---------- x = self.conv2(x) x = self.bn2(x) x = self.relu2(x) x = self.ca2(x) # 应用通道注意力 x = self.pool2(x) # ---------- 卷积块3处理 ---------- x = self.conv3(x) x = self.bn3(x) x = self.relu3(x) x = self.ca3(x) # 应用通道注意力 x = self.pool3(x) # ---------- 展平与全连接层 ---------- x = x.view(-1, 128 * 4 * 4) x = self.fc1(x) x = self.relu3(x) x = self.dropout(x) x = self.fc2(x) return x # 重新初始化模型,包含通道注意力模块 model = CNN() model = model.to(device) # 将模型移至GPU(如果可用) criterion = nn.CrossEntropyLoss() # 交叉熵损失函数 optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器 # 引入学习率调度器,在训练过程中动态调整学习率--训练初期使用较大的 LR 快速降低损失,训练后期使用较小的 LR 更精细地逼近全局最优解。 # 在每个 epoch 结束后,需要手动调用调度器来更新学习率,可以在训练过程中调用 scheduler.step() scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, # 指定要控制的优化器(这里是Adam) mode='min', # 监测的指标是"最小化"(如损失函数) patience=3, # 如果连续3个epoch指标没有改善,才降低LR factor=0.5 # 降低LR的比例(新LR = 旧LR × 0.5) ) # 训练模型(复用原有的train函数) print("开始训练带通道注意力的CNN模型...") final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs=50) print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

@浙大疏锦行

http://www.jsqmd.com/news/147548/

相关文章:

  • 教育场景应用:基于TensorFlow的在线编程实验平台
  • 【Open-AutoGLM高效定位秘诀】:90%工程师忽略的4个关键参数配置
  • 如何快速制作专业地图演示:免费矢量素材完整指南
  • 为什么顶尖团队都在抢用Open-AutoGLM?,它到底解决了哪些测试痛点
  • SeedVR2视频修复完整指南:告别Sora2模糊画质的终极方案
  • 2026餐饮老板:EMBA太贵,AI太深,一张证书能补课吗?
  • 树莓派5引脚定义与继电器模块连接实践指南
  • 树莓派插针定义图文说明:4B版本复位引脚详解
  • ESP32教程:Arduino IDE控制舵机角度精准调节实践
  • 知乎专栏写作:发布高质量TensorFlow技术问答
  • 树莓派连接Home Assistant入门必看指南
  • 视频修复新利器:SeedVR2实战应用全解析
  • Open-LLM-VTuber完整指南:打造你的专属AI虚拟主播
  • 基于Vue3与Three.js的3D球体抽奖系统技术解析
  • Open-AutoGLM启动卡在第一步?这7个预检项你必须立即检查
  • PingFangSC字体包:免费开源跨平台字体解决方案终极指南
  • 树莓派烧录批量部署:多卡同步写入实战案例
  • 3D抽奖系统终极指南:5分钟快速搭建企业级互动平台
  • 终极指南:在Windows 7上安装Python 3.9+的完整教程
  • 2025年保定靠谱精准营销服务商排行榜,河北集创市场口碑如何? - 工业推荐榜
  • LongCat-Video:13.6亿参数开源视频生成模型,5分钟长视频创作革命
  • Kubeadm安装K8S集群
  • Real-ESRGAN终极指南:三步实现图片视频智能修复
  • 2025年年终膜结构厂家推荐:从设计能力到施工团队的专业维度对比与5家高口碑厂家聚焦 - 品牌推荐
  • 读共生:4_0时代的人机关系02人机合作后
  • 大文件处理利器:TFRecord格式设计与优化建议
  • 单点登录集成:OAuth2.0接入TensorFlow Web门户
  • 2025年上海网站建设十大品牌权威评测 - 行业调查分析报告 - 匠子网络
  • 2025年企业展厅设计公司推荐,技术先进的企业展厅设计服务公司全解析 - 工业品牌热点
  • 2025年北京婚内财产协议律师联系方式汇总: 核心城区资深律师联系通道与高效咨询指引 - 十大品牌推荐