数据无关知识蒸馏(Data-Free Knowledge Distillation, DFKD),解决在没有或者少有真实数据的情况下,怎么蒸馏某个模型的问题。
[[KD 基于高温 Softmax 的 Logits 模拟]]
设计思想
DFKD将蒸馏转化为逆向生成问题,也就是如何通过模型参数或者相关的东西还原原始数据的流行分布。以模拟教师网络的内部信号(如BatchNorm统计量、权重先验、特征协方差或输出置信度)作为优化目标,训练一个轻量生成器逼近真实数据分布;随后在合成样本上执行标准 KD(Logits/特征对齐),实现知识传递。该思想规避了数据版权与隐私壁垒,使模型压缩与迁移完全解耦于原始语料。
算法流程
- 生成器初始化:以高斯噪声或低维隐变量为输入,构建轻量生成网络(根据神秘言论,早期多用 GAN 架构,近年引入扩散先验或自回归Token生成器)。
- 伪数据优化(固定教师,更新生成器):
最小化复合损失 \(\mathcal{L}_{gen} = \lambda_1 \mathcal{L}_{BN} + \lambda_2 \mathcal{L}_{feat} + \lambda_3 \mathcal{L}_{prior} - \lambda_4 \mathcal{L}_{student} + \cdots\)- \(\mathcal{L}_{BN}\):匹配教师各层BatchNorm的均值/方差,约束生成样本的统计特性。
- \(\mathcal{L}_{feat}\):约束特征图稀疏性、Gram矩阵或注意力分布,逼近真实激活流形。(这一点不是很明白)
- \(\mathcal{L}_{prior}\):注入模态先验(如图像Total Variation正则、文本Token频率分布),防止生成器退化。
- \(\cdots\):可以有更多的衡量标准
- 学生蒸馏(固定生成器与教师,更新学生):
- 在合成批次上计算标准KD损失 \(\mathcal{L}_{student} = \alpha \cdot \text{KL}(p_T^\tau \| p_S^\tau) + \beta \cdot \text{MSE}(F_T, F_S)\),交替迭代直至学生性能收敛。
- 还有其他的判断方法,反正大概是通过数学手段,验证学生模型和教师模型的流形空间是否一致,验证的方法就是尽可能利用其中的各种参数。
局限
- 生成器易陷入模式崩溃(Mode Collapse),伪数据难以覆盖长尾/多模态分布,学生泛化上限受限于合成质量。
- DFKD 高度依赖 BatchNorm,但是现在 LLM 框架多采用 LayerNorm。根据神秘言论,需重构损失(如基于注意力图对齐、Token熵匹配或Prompt引导生成),工程复杂度陡增。
- 生成器训练与交替优化会显著提高计算成本,且对超参数敏感,需精细调优。
