07 DeiT 论文精读:Training data-efficient image transformers distillation through attention
前言
在前面的章节中,我们已经理解了 ViT 的核心思想:
图像 → Patch Embedding → Token 序列 → Transformer Encoder → 分类结果
ViT 原论文证明了一件非常重要的事情:纯 Transformer 架构可以直接用于图像识别任务。
但是 ViT 也留下了一个明显问题:
ViT 很强,但它对大规模数据和训练资源的依赖比较明显。原始 ViT 的强性能通常依赖大规模预训练数据和较高训练成本,这对于普通实验室和普通研究者并不友好。DeiT 正是围绕这个问题提出的。DeiT 论文全名为 Training>论文中明确指出,DeiT 在只使用 ImageNet 的情况下训练了有竞争力的无卷积 Transformer,并且引入了适合 Transformer 的 teacher-student 蒸馏策略,其中 distillation token 是核心设计。
2. DeiT 论文想解决什么问题?
DeiT 的提出背景非常明确。ViT 原论文证明了纯 Transformer 可以用于图像识别,但原始 ViT 的强性能很大程度上依赖大规模预训练数据。例如,ViT 原论文中经常使用 ImageNet-21k、JFT-300M 等大规模数据。对于大公司或大规模算力平台来说,这种设定可以接受;但对于普通研究者来说,训练成本过高。所以 DeiT 想回答的问题是:
如果只使用 ImageNet-1K,不使用额外大规模外部数据,能不能训练出性能强的 Vision Transformer?
论文摘要中提到,DeiT 的 reference vision transformer 拥有约 86M 参数,在不使用外部数据的情况下达到 83.1% ImageNet Top-1 accuracy;论文还报告了通过蒸馏后最高可达到 85.2% Top-1 accuracy。因此,DeiT 的核心问题不是:Transformer 能不能做图像分类?
这个问题 ViT 已经回答了。DeiT 真正要解决的是:
Transformer 能不能在有限数据和有限算力条件下训练好?这就是 DeiT 中 “data-efficient” 的含义。
3. DeiT 和 ViT 的关系
DeiT 和 ViT 的关系非常紧密。ViT 的核心结构是:
Patch Embedding
Class Token
Position Embedding
Transformer Encoder
Classification Head
DeiT 并没有推翻这个框架。相反,DeiT 基本沿用了 ViT 的主体结构。它的重点不是发明一种全新的视觉 Transformer 架构,而是解决 ViT 的训练效率问题。可以这样理解:
ViT:证明图像可以被看作 patch token 序列,并输入 Transformer。
DeiT:证明在不依赖超大规模外部数据的情况下,ViT 也可以通过更好的训练策略和蒸馏机制训练得很好。
所以,DeiT 是 ViT 之后非常关键的一篇工作。它解决的是 ViT 从“能跑通”到“更容易训练、更容易复现、更适合普通实验条件”的问题。
4. DeiT 的核心贡献
DeiT 的贡献可以概括为三点。
4.1 提出一套更高效的 ViT 训练方案
DeiT 证明,只使用 ImageNet-1K,也可以训练出很强的 Vision Transformer。论文强调其方法可以在单台计算机上较短时间内训练出有竞争力的模型,这大大降低了 ViT 的使用门槛。这说明 ViT 的性能不仅取决于模型结构,也强烈依赖训练策略。
DeiT 中的 “data-efficient” 可以从两个层面理解。
第一层含义是:不依赖超大外部数据。
ViT 原论文中,大规模预训练是非常重要的。DeiT 则希望在 ImageNet-1K 这样的标准数据规模下训练 Transformer。
第二层含义是:在有限的数据下提高训练效果
也就是说,同样只使用 ImageNet-1K,DeiT 通过更强的数据增强、正则化、优化策略和蒸馏机制,让 ViT 学得更好。所以 DeiT 的目标不是简单地减少数据量,而是提高数据使用效率。可以概括为:
ViT 依赖大规模数据学习视觉规律;
DeiT 通过训练策略和 teacher supervision 提高 ViT 的数据效率。
DeiT 的另一个重要观点是:
对于 Vision Transformer,训练 recipe 和模型结构同样重要。这里的训练 recipe 可以理解为一整套训练配置,包括:
数据增强
正则化
优化器
学习率策略
warmup
label smoothing
mixup
cutmix
random erasing
stochastic depth
知识蒸馏
这些细节对 ViT 尤其重要。原因是 ViT 的图像归纳偏置比 CNN 更弱。CNN 天然具有局部连接、权重共享和平移等变性,而 ViT 更依赖数据和训练目标自己学习图像结构。因此,如果训练策略不够强,ViT 在 ImageNet-1K 上可能训练不充分。DeiT 的启发是:
ViT 的问题不只是结构问题,也是训练问题。
4.2 引入适合 Transformer 的蒸馏机制
传统知识蒸馏通常是让 student 模型学习 teacher 模型的输出分布。DeiT 的创新在于:
不是只在 loss 上做蒸馏, 而是在 Transformer 输入序列中加入一个 distillation token。这个 distillation token 会和 class token、patch token 一起进入 Transformer Encoder,通过 self-attention 参与表示学习。这也是论文标题中:distillation through attention的含义。
要理解 DeiT,必须先理解知识蒸馏。知识蒸馏的基本框架是:
Teacher Model:一个已经训练好的强模型
Student Model:一个需要训练的模型
普通监督学习中,student 只学习真实标签:image → label。而知识蒸馏中,student 还要学习 teacher 的输出:image → teacher prediction。teacher 的输出通常比 one-hot 标签包含更多信息。例如,一张猫的图片,真实标签只告诉模型:
cat = 1
其他类别 = 0
但是 teacher 的输出可能是:
cat = 0.82
tiger = 0.08
dog = 0.04
fox = 0.02
...
这种分布表达了类别之间的相似关系。所以蒸馏的意义是
真实标签告诉 student 正确答案;
teacher 输出告诉 student 类别之间的关系和判断倾向。
这对 ViT 很有帮助,因为 ViT 缺少 CNN 那种强图像先验,需要更丰富的监督信号。
4.3 证明 ConvNet Teacher 对 ViT Student 特别有效
DeiT 论文中指出,使用 ConvNet 作为 teacher 对 Transformer student 的蒸馏尤其有效。直观上,CNN 具有更强的局部视觉归纳偏置,而 ViT 缺少这种先验。因此,CNN teacher 可以向 ViT student 传递有用的视觉判断经验。这也从侧面说明:ViT 的训练困难并不是结构无效,而是它需要更好的监督信号和训练策略来学习视觉规律。
5. DeiT 的关键设计:Distillation Token
DeiT 最核心的结构设计就是:
distillation token在标准 ViT 中,输入序列是:
[CLS], patch_1, patch_2, ..., patch_196而在 DeiT 中,输入序列变成:
[CLS], [DIST], patch_1, patch_2, ..., patch_196其中:
[CLS]:class token,用于学习真实标签监督 [DIST]:distillation token,用于学习 teacher 监督这意味着 DeiT 比 ViT 多了一个特殊 token。对于 224×224 输入、patch size 为 16 的模型:
ViT token 数量: 196 patch tokens + 1 class token = 197 DeiT distilled token 数量: 196 patch tokens + 1 class token + 1 distillation token = 198这个 distillation token 不是图像 patch 切出来的,而是一个可学习参数,和 class token 一样会参与 Transformer Encoder 中的 self-attention。
5.1 Class Token 和 Distillation Token 的区别
class token 和 distillation token 很像,但目标不同。class token 主要用于真实标签分类。它经过 Transformer Encoder 后,接分类头输出:class head output。然后和真实标签计算分类损失。可以理解为:CLS token 负责学习 ground-truth label。distillation token 主要用于学习 teacher。它经过 Transformer Encoder 后,接另一个分类头:distillation head output。然后和 teacher 输出计算蒸馏损失。可以理解为:DIST token 负责学习 teacher prediction。
5.2 二者会不会互相影响?
会。因为它们都在同一个 Transformer Encoder 中。输入序列是:
[CLS], [DIST], patch_1, patch_2, ..., patch_196
在 self-attention 中,每个 token 都可以和其他 token 交互。所以:
CLS token 可以关注 patch token;
DIST token 可以关注 patch token;
CLS token 和 DIST token 之间也可以互相关注。
这就是 DeiT 中 “through attention” 的关键。distillation token 不是一个独立分支,而是通过 attention 融入整个 Transformer 表示学习过程。
6. DeiT 的模型结构解析
DeiT 的整体结构可以写成:
Input Image
↓
Patch Embedding
↓
Add Class Token
↓
Add Distillation Token
↓
Add Position Embedding
↓
Transformer Encoder
↓
取 CLS token 和 DIST token
↓
Class Head + Distillation Head
和 ViT 相比,主要变化有三个:
1. 多了 distillation token;
2. position embedding 长度增加 1;
3. 多了 distillation head。
以 DeiT-B/16 为例:
patch tokens: 196
class token: 1
distillation token: 1
total tokens: 198
embedding dim: 768
所以 Transformer Encoder 的输入形状为:[B, 198, 768]
而不是 ViT-B/16 的:[B, 197, 768]
最后输出时:
x_cls = x[:, 0]
x_dist = x[:, 1]
其中:
x_cls:class token 输出
x_dist:distillation token 输出
分别接不同的分类头。
7. DeiT 的蒸馏损失函数解析
DeiT 的训练目标由两部分组成:
1. 普通分类损失
2. 蒸馏损失
可以写成:
Total Loss = (1 - α) × Classification Loss + α × Distillation Loss
其中:α:控制普通分类损失和蒸馏损失的权重
在官方 DeiT 代码中,DistillationLoss 会先计算基础分类损失 base_loss,如果启用蒸馏,则再用 teacher model 对原始输入进行预测,然后根据 distillation_type 计算 soft 或 hard distillation loss,最后按 alpha 加权组合。
普通分类损失通常是:Classification Loss = CE(class_head_output, ground_truth_label)
蒸馏损失根据蒸馏方式不同,可以分为 soft distillation 和 hard distillation。
7.1 Soft Distillation
Soft distillation 学习 teacher 的概率分布。teacher 会输出每个类别的概率,例如:
cat: 0.82
dog: 0.08
tiger: 0.04
...
student 的 distillation head 要尽量接近 teacher 的输出分布。通常使用 KL divergence 进行约束。
形式上可以理解为:Distillation Loss = KL(student_distribution, teacher_distribution)
在官方 DeiT 代码中,soft distillation 使用 F.kl_div,对 student 的 distillation 输出和 teacher 输出都进行 temperature scaling,并乘以 T*T 进行尺度修正。也就是说,soft distillation 学的是:teacher 的完整判断分布。它不仅告诉 student 哪个类别最可能,还告诉 student 类别之间的相似关系。
7.2 Hard Distillation
Hard distillation 学习 teacher 的预测类别。teacher 输出 logits 后,取最大概率对应的类别:
teacher_label = argmax(teacher_output)。然后 student 的 distillation head 学习这个 teacher label。官方 DeiT 代码中,hard distillation 使用:
F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
也就是说,hard distillation 不是学习完整概率分布,而是学习 teacher 给出的硬标签。可以这样理解:
Soft distillation:学习 teacher 的判断分布
Hard distillation:学习 teacher 的最终答案
7.3 teacher model
在 DeiT 中,teacher 指的是知识蒸馏中的教师模型。DeiT 本身是 student model,训练时不仅学习真实标签,还会学习 teacher model 的预测结果。
DeiT 原论文中默认使用的 teacher 是 RegNetY-16GF,这是一个卷积神经网络模型,参数量约为 84M。论文中使用与 DeiT 相同的数据和数据增强方式训练该 teacher,其 ImageNet Top-1 accuracy 为 82.9%。
DeiT 之所以选择 CNN teacher,是因为 CNN 具有更强的图像归纳偏置,例如局部连接、层次化特征提取和局部纹理建模能力。ViT 的归纳偏置较弱,通过 distillation token 向 CNN teacher 学习,可以帮助 Transformer student 获得更好的视觉监督信号。
