当前位置: 首页 > news >正文

从ResNet到ResNeSt:手把手教你用PyTorch复现Split-Attention注意力机制

从ResNet到ResNeSt:手把手教你用PyTorch复现Split-Attention注意力机制

在计算机视觉领域,注意力机制已经成为提升模型性能的关键技术。ResNeSt作为ResNet的改进版本,通过引入Split-Attention机制,在保持ResNet简洁架构的同时,显著提升了特征表达能力。本文将深入解析Split-Attention的实现细节,带你从零开始用PyTorch实现这一创新模块。

1. Split-Attention核心原理剖析

Split-Attention的核心思想是将特征图在通道维度上进行多级分组,并在不同组之间建立注意力交互。这种设计既保留了分组卷积的计算效率,又通过注意力机制增强了特征表达能力。

具体来说,Split-Attention包含三个关键步骤:

  1. 基数分组(Cardinal Groups):将输入特征图划分为K个基数组
  2. 径向划分(Radix Splits):在每个基数组内进一步划分为R个子组
  3. 注意力融合:基于全局上下文信息计算各子组的注意力权重

这种双重分组结构可以用以下公式表示:

总分组数 = 基数(K) × 径向数(R)

在PyTorch中,我们可以通过group参数实现基数分组,而径向划分则需要更精细的张量操作。下面是一个简单的分组示意图:

操作步骤输入形状输出形状说明
基数分组(B,C,H,W)(B,K,C/K,H,W)沿通道维度分组
径向划分(B,K,C/K,H,W)(B,K,R,C/(KR),H,W)每组内再划分
注意力计算(B,K,R,C/(KR),H,W)(B,K,C/K,H,W)加权融合

2. RadixSoftmax模块实现

RadixSoftmax是Split-Attention的核心组件,负责计算各子组的注意力权重。与常规Softmax不同,它需要在特定维度上进行归一化。

class RadixSoftmax(nn.Module): def __init__(self, radix, cardinality): super().__init__() self.radix = radix # 每个基数组下的子组数 self.cardinality = cardinality # 基数组数量 def forward(self, x): batch = x.size(0) if self.radix > 1: # 将输入重塑为(B, K, R, C/(KR))形式 x = x.view(batch, self.cardinality, self.radix, -1) # 在径向维度(R)上计算Softmax x = F.softmax(x, dim=2) x = x.reshape(batch, -1) else: x = torch.sigmoid(x) return x

这个实现有几个关键点:

  1. 当radix=1时退化为Sigmoid,相当于SE模块的注意力机制
  2. 通过view和reshape操作实现张量的高效重组
  3. Softmax仅在径向维度计算,保持基数组间的独立性

3. SplitAttn模块完整实现

基于RadixSoftmax,我们可以构建完整的SplitAttn模块。以下是逐步实现过程:

class SplitAttn(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, radix=2, groups=1, norm_layer=nn.BatchNorm2d): super().__init__() out_channels = out_channels or in_channels self.radix = radix # 中间通道数 = 输出通道 × radix mid_chs = out_channels * radix # 注意力计算通道数 attn_chs = max(in_channels * radix // 8, 32) # 主卷积路径 self.conv = nn.Conv2d( in_channels, mid_chs, kernel_size, stride=stride, padding=kernel_size//2, groups=groups * radix, bias=False) self.bn0 = norm_layer(mid_chs) self.act0 = nn.ReLU(inplace=True) # 注意力路径 self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) self.bn1 = norm_layer(attn_chs) self.act1 = nn.ReLU(inplace=True) self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) self.rsoftmax = RadixSoftmax(radix, groups)

前向传播过程需要仔细处理张量形状变换:

def forward(self, x): x = self.conv(x) x = self.bn0(x) x = self.act0(x) B, RC, H, W = x.shape if self.radix > 1: # 将特征图拆分为radix个子组 x = x.view(B, self.radix, RC//self.radix, H, W) # 对各子组特征求和 x_gap = x.sum(dim=1) else: x_gap = x # 计算全局平均池化 x_gap = x_gap.mean([2,3], keepdim=True) # 计算注意力权重 x_attn = self.fc1(x_gap) x_attn = self.bn1(x_attn) x_attn = self.act1(x_attn) x_attn = self.fc2(x_attn) x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) # 应用注意力权重 if self.radix > 1: out = (x * x_attn.view(B, self.radix, RC//self.radix, 1, 1)).sum(dim=1) else: out = x * x_attn return out.contiguous()

4. ResNeSt Bottleneck集成

将SplitAttn集成到ResNet的Bottleneck中,形成完整的ResNeSt模块:

class ResNestBottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, radix=2, cardinality=1, base_width=64): super().__init__() group_width = int(planes * (base_width / 64.)) * cardinality self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(group_width) self.act1 = nn.ReLU(inplace=True) self.conv2 = SplitAttn( group_width, group_width, kernel_size=3, stride=stride, radix=radix, groups=cardinality) self.conv3 = nn.Conv2d(group_width, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.act3 = nn.ReLU(inplace=True) self.downsample = downsample

前向传播保持了ResNet的经典残差结构:

def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.act1(out) out = self.conv2(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.act3(out) return out

5. 实战技巧与性能优化

在实际实现ResNeSt时,有几个关键点需要注意:

  1. 基数与径向数的选择

    • 基数(Cardinality)通常设置为1或2
    • 径向数(Radix)常用值为2或4
    • 两者乘积不宜过大,否则会显著增加计算量
  2. 内存优化

    • 使用contiguous()确保张量内存布局连续
    • 合理设置groups参数利用分组卷积优化
  3. 训练技巧

    • 学习率warmup有助于稳定训练
    • 标签平滑(Label Smoothing)可以提升泛化能力
    • 大型batch训练时需要调整BN参数

以下是不同配置下的计算量对比:

模型Params(M)FLOPs(G)Top-1 Acc(%)
ResNet-5025.54.176.2
ResNeSt-50 (radix=2)27.54.378.3
ResNeSt-50 (radix=4)30.14.779.1

在实现过程中,我发现最易出错的地方是张量形状变换。特别是在SplitAttn模块中,需要确保:

  1. 分组卷积的groups参数正确设置为cardinality × radix
  2. 注意力权重的计算与原始特征图维度匹配
  3. 残差连接时的通道数对齐

一个实用的调试技巧是添加shape检查断言:

assert x.shape == (B,C,H,W), f"Expected shape {(B,C,H,W)}, got {x.shape}"

通过PyTorch的灵活张量操作,我们可以高效实现Split-Attention机制。相比原始论文的TensorFlow实现,PyTorch版本通常能获得更好的运行时性能,特别是在使用torch.jit.script优化后。

http://www.jsqmd.com/news/676620/

相关文章:

  • 3步实现AI到PSD完美转换:Ai2Psd脚本终极指南
  • 官方认证|2026年五大正规番禺驾校排名,广州随约驾驶学校有限公司口碑断层领先 - 博客万
  • Mac用户终极抢票指南:如何用12306ForMac轻松搞定春运车票 [特殊字符]
  • 压力机振动危害与科学治理科普
  • 从‘dangerous relocation’报错,聊聊AArch64架构下静态库与动态库混用的那些坑
  • 深度分析知名的加拿大海运企业,乐成国际物流靠谱之选 - myqiye
  • FUXA:基于Web的工业可视化系统,从零构建专业级监控平台
  • VS2019配置libxl库踩坑实录:从‘无法解析的外部符号’到成功生成Excel文件
  • 一劳永逸解决Windows和Office激活难题:KMS智能激活终极方案
  • UnrealPakViewer:5个关键技巧帮你轻松管理虚幻引擎Pak文件资源
  • 避坑指南:Unity阿拉伯语适配中那些‘看起来对但实际是错’的显示问题
  • AI专著撰写秘籍!AI写专著工具助力,3天完成20万字专著写作!
  • 云原生安全与合规:OPA Gatekeeper + Kyverno + Trivy 实战指南(建议收藏)
  • PyTorch张量操作保姆级教程:从arange创建到广播机制,新手避坑指南
  • 信号处理中的插值与采样技术详解
  • 2026年衬塑设备制造商中如皋佳百费用如何,听听用户评价 - 工业推荐榜
  • 告别轮询:用ibv_req_notify_cq和事件驱动优化你的RDMA应用性能
  • 【Matlab代码】基于SCSSA-CNN-BiGRU-Attention(改进麻雀搜索算法优化双向门控循环单元网络)多变量回归预测
  • PinWin:你的窗口为何总被遮挡?这款开源神器让重要信息永不消失
  • 超越默认样式:手把手教你用mplfinance定制专属量化图表风格(从配色到字体)
  • M62429L双声道音量IC驱动:从硬件引脚到软件时序的实战解析
  • 别再死记硬背了!用Python+Jupyter Notebook手把手教你计算化学反应吉布斯自由能变
  • 【ArcGIS Pro二次开发】:三调地类面积精准统计与数据清洗实战
  • 5分钟搞定OFD转PDF:开源神器Ofd2Pdf终极使用指南
  • USB PD PPS便携电源设计:原理与工程实践
  • VHDL并发信号赋值与BLOCK语句实战解析
  • 齿轮箱零部件及其装配质检中的TVA技术突破(18)
  • 聊聊不错的转接线厂家,钦利发口碑如何? - 工业品网
  • MATLAB绘图避坑:箭头颜色总是不对?一文搞懂arrow3和quiver3的颜色控制机制
  • CodeForces-2168B Locate 题解