别再死记硬背了!用代码拆解ViT和DETR,搞懂Transformer处理图像的真正逻辑
用代码拆解ViT和DETR:Transformer处理图像的底层逻辑
第一次看到Vision Transformer(ViT)和DETR的论文时,我完全被那些数学符号和抽象描述搞晕了。直到有一天,我决定打开PyTorch源码,一行行调试这些模型,才发现原来Transformer处理图像的逻辑如此直观。本文将带你用代码视角重新理解这两个开创性工作,特别是它们如何处理图像这种二维数据。
1. 图像如何变成Transformer的"语言"
传统Transformer是为序列数据设计的,比如文本。但图像本质上是二维像素矩阵,要让Transformer理解图像,首先需要解决数据表示问题。ViT给出的方案简单却有效——将图像切割成小块(patch),然后线性嵌入到向量空间。
import torch import torch.nn as nn class PatchEmbedding(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x) # (B, C, H, W) -> (B, E, H/P, W/P) x = x.flatten(2) # (B, E, N) where N = (H*W)/P^2 x = x.transpose(1, 2) # (B, N, E) return x这段代码展示了ViT如何处理一张224×224的RGB图像:
- 使用16×16的卷积核(无重叠)切割图像,得到14×14=196个patch
- 每个patch被展平为16×16×3=768维向量
- 通过线性投影将每个patch映射到embedding空间
提示:虽然使用卷积实现,但这本质上等同于将图像网格化后做线性变换。选择16×16的patch是在计算效率和局部信息保留间的平衡。
2. 位置编码:让Transformer记住空间关系
图像切割成patch后失去了原始空间信息,而视觉任务高度依赖位置关系。Transformer通过位置编码(positional encoding)解决这个问题:
class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(max_len, d_model) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # (1, max_len, d_model) self.register_buffer('pe', pe) def forward(self, x): x = x + self.pe[:, :x.size(1)] return x关键点:
- 使用正弦函数生成编码,确保不同位置有唯一标识
- 编码维度与embedding维度相同,可以直接相加
- 在ViT中,每个patch的位置是固定的,因此只需一次计算
有趣的是,DETR采用了可学习的位置编码,这让模型可以自适应地调整位置表示:
# DETR的位置编码实现 self.row_embed = nn.Parameter(torch.rand(50, 256 // 2)) self.col_embed = nn.Parameter(torch.rand(50, 256 // 2))3. DETR的魔法:可学习的目标查询
DETR最令人困惑的部分是那100个"神秘"的query向量。让我们用代码揭开它的面纱:
# DETR decoder初始化 self.query_embed = nn.Embedding(num_queries, hidden_dim) # 通常num_queries=100, hidden_dim=256 # 在forward中: query_embed = self.query_embed.weight.unsqueeze(0).repeat(bs, 1, 1) tgt = torch.zeros_like(query_embed) decoder_output = self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)这些query向量本质上是模型需要学习的"问题":
- 每个query负责提出一个目标检测问题(如"图像右下角有车吗?")
- 通过多轮attention机制与图像特征交互,逐步细化预测
- 匈牙利匹配确保每个真实框只对应一个预测
下表对比了ViT和DETR的核心差异:
| 特性 | ViT | DETR |
|---|---|---|
| 输入处理 | 规则网格切割 | CNN特征图展平 |
| 位置编码 | 固定正弦编码 | 可学习的二维编码 |
| Transformer应用 | 纯Encoder架构 | Encoder-Decoder架构 |
| 输出处理 | 分类头 | 预测框+类别 |
| 核心创新点 | 图像序列化 | 端到端目标检测 |
4. 从理论到实践:调试技巧
理解这些模型最好的方式是实际调试。以下是几个实用技巧:
- 可视化attention地图:
# 获取ViT最后一层的attention权重 attn_weights = model.blocks[-1].attn.attn.detach() # 平均所有head的attention avg_attn = attn_weights.mean(dim=1)- 跟踪query变化:
# 在DETR decoder每层后记录query状态 class DebugDecoder(nn.Module): def forward(self, tgt, memory, pos, query_pos): for layer in self.layers: tgt = layer(tgt, memory, pos, query_pos) print(f"Layer {layer}: query norm {tgt.norm()}") return tgt- 简化实验设置:
- 先用1-2张图片调试,观察中间结果
- 尝试减少patch大小或query数量,加速实验
- 使用预训练模型快速验证想法
注意:调试Transformer模型需要大显存,建议使用梯度检查点或降低batch size。
第一次跑通DETR训练时,我观察到那些初始为0的query向量在训练过程中逐渐分化,各自"专攻"不同位置和尺寸的目标检测。这种动态学习过程比任何理论解释都更直观地展示了Transformer的威力。
