从ViT到PVT:SRA模块如何让Transformer在CV任务上‘瘦身’成功?
从ViT到PVT:SRA模块如何让Transformer在CV任务上‘瘦身’成功?
当Vision Transformer(ViT)首次将Transformer架构引入计算机视觉领域时,它展现出了媲美甚至超越传统CNN的性能。然而,随着图像分辨率的提升,标准ViT模型的计算复杂度呈平方级增长,这让许多研究者开始思考:能否在保持全局感受野的同时,让Transformer在视觉任务中变得更高效?Pyramid Vision Transformer(PVT)给出的答案是SRA(Spatial Reduction Attention)模块——这个看似简单的设计,却让Transformer在CV任务上成功"瘦身"。
1. ViT的计算瓶颈与PVT的诞生
2017年Transformer在NLP领域大获成功后,研究者们开始探索其在计算机视觉中的应用可能。2020年,Vision Transformer(ViT)首次证明,纯Transformer架构在图像分类任务上可以达到与CNN相当甚至更好的性能。然而,ViT在处理高分辨率图像时面临严峻挑战:
- 计算复杂度问题:标准自注意力机制的计算复杂度与图像patch数量的平方成正比
- 内存占用过高:高分辨率图像会导致注意力矩阵变得极其庞大
- 特征金字塔缺失:ViT缺乏CNN那种天然的多尺度特征表示能力
Pyramid Vision Transformer(PVT)的提出正是为了解决这些问题。与ViT的"平坦"结构不同,PVT引入了类似CNN的金字塔结构,通过四个stage逐步下采样特征图。但PVT最关键的创新在于其注意力机制的设计——Spatial Reduction Attention(SRA)模块。
2. SRA模块的核心设计思想
SRA模块的精妙之处在于它重新思考了自注意力机制中的K(key)和V(value)矩阵的作用。传统多头注意力中,Q、K、V三者通常保持相同的维度,这导致了不必要的计算冗余。SRA通过两个关键设计实现了效率提升:
2.1 空间缩减策略
SRA模块的核心创新是对K和V矩阵进行空间维度的缩减:
# 传统多头注意力 q = query @ W_q # [N, C] k = key @ W_k # [N, C] v = value @ W_v # [N, C] # SRA中的处理 k = spatial_reduction(key) @ W_k # [N/R, C] v = spatial_reduction(value) @ W_v # [N/R, C]其中spatial_reduction可以通过卷积或池化操作实现,缩减比例R通常设置为4或8。这种设计带来了显著的效率提升:
| 操作 | 计算复杂度 | 内存占用 |
|---|---|---|
| 标准注意力 | O(N²) | O(N²) |
| SRA (R=4) | O(N²/16) | O(N²/16) |
| SRA (R=8) | O(N²/64) | O(N²/64) |
2.2 保持输出维度不变
尽管对K和V进行了降维处理,SRA通过巧妙的矩阵运算保持了输出维度与标准注意力一致:
- Q矩阵保持原始维度 [N, C]
- 降维后的K矩阵为 [N/R, C]
- 注意力分数计算为 QK^T → [N, N/R]
- 与降维后的V [N/R, C]相乘 → [N, C]
这种设计确保了SRA模块可以无缝替换标准注意力,而不会影响下游网络结构。
3. SRA在PVT中的实际应用效果
PVT将SRA模块应用于其金字塔结构的每个stage中,实现了计算效率与模型性能的平衡。在实际CV任务中,SRA带来了显著优势:
3.1 图像分类任务
在ImageNet分类任务上,PVT展现了与CNN相当的性能,同时计算量大幅降低:
| 模型 | Top-1 Acc | FLOPs | 参数量 |
|---|---|---|---|
| ResNet-50 | 76.2% | 4.1G | 25M |
| ViT-B/16 | 77.9% | 17.6G | 86M |
| PVT-Small | 79.8% | 3.8G | 24M |
| PVT-Medium | 81.2% | 6.7G | 44M |
3.2 目标检测与分割
SRA的另一个重要优势是支持密集预测任务。在COCO目标检测和ADE20K分割任务中,PVT作为backbone展现出强大性能:
目标检测结果(RetinaNet框架)
| Backbone | AP@0.5 | AP@0.75 | AP@[0.5:0.95] |
|---|---|---|---|
| ResNet-50 | 50.9 | 34.7 | 36.3 |
| PVT-Small | 53.1 | 36.2 | 38.2 |
| PVT-Medium | 54.7 | 37.5 | 39.5 |
注意:SRA模块特别适合处理高分辨率特征图,这使得PVT在密集预测任务中优势明显
4. SRA的演进与优化
PVT团队在后续工作中持续优化SRA模块,主要体现在两个方向:
4.1 从卷积到池化的演进
PVT v2中将SRA中的空间缩减操作从卷积改为池化,进一步减少了参数量:
- 卷积实现:需要学习卷积核参数
- 池化实现:无参数操作,计算更高效
这种改变使得PVT v2在保持性能的同时,模型更加轻量。
4.2 多任务适应性改进
研究者们发现SRA模块在不同计算机视觉任务中表现有所差异:
- 分类任务:较大的缩减比例(R=8)效果更好
- 检测任务:中等缩减比例(R=4)更合适
- 分割任务:需要平衡感受野与细节保持
这种观察促使后续工作开发了动态调整缩减比例的自适应SRA变体。
5. SRA对Transformer架构的启示
SRA模块的成功为Transformer在CV领域的优化提供了重要启示:
- 不是所有注意力都需要完整计算:适当降维可以大幅提升效率
- 空间局部性在视觉任务中很重要:这与NLP中的全局注意力形成对比
- 金字塔结构在视觉Transformer中很有效:多尺度特征表示仍然关键
在实际项目中应用PVT时,有几个经验值得分享:当处理高分辨率图像(如医疗影像或卫星图像)时,适当增大初始阶段的缩减比例可以显著降低内存消耗;而在需要精细定位的任务中,最后阶段的缩减比例不宜过大。
