对于 [[函数#^f5ba8e|Softmax]] 函数,如果引入温度 \(T\),那么:
\[q_i = \frac {\exp (z_i / T)}{\sum \exp (z_j / T)}
\]
会使得 \(q_i\) 间的差距变小,也就是会放大真实标签小的权重,使得其显现化。
于是对于两个模型 \(\theta, \phi\),如果有着相同的 logits 输出空间 \(l_\theta, l_\phi\),最终的,那么就可以在一个共有数据集下使用这样一个损失函数:
\[L(\phi; \theta) = \alpha L(\phi) + \beta D_{KL} ({\rm softmax}_T(l_\theta) \| {\rm softmax}_T(l_\phi))
\]
基于这个损失函数训练 \(\phi\) 即可。
值得注意的是,由于对于 \({\rm softmax}_T\),其导数会多一个 \(\frac 1 T\) 的常数,另外,通过 \(T\) 较大时:\(\exp \frac x T \approx 1 + \frac x T\),使得对于求 \(\rm softmax\) 偏导后的 \(q_{\theta, i} - q_{\phi, i}\) 会变成原本的 \(O(\frac 1 T)\) 量级,也就是 \(\nabla D_{KL}\) 会变成原本的 \(O(\frac 1 {T^2})\) 量级。所以可以将公式修正为 \(\beta' = \beta T^2\)
一种特殊情形: 直接学习 logits (不经过softmax),并利用 MSE 设计 Loss。
本质上和利用 KL 散度是一致的,可以考虑其偏导都是实际输出的差值,也就是说,在 logits 分布一致的情况下,直接学习 logits 完全可行。
但是如果分布一个大,一个小,那么这样学习可能就会很困难。
