Softmax函数大揭秘:从原理到应用,你了解多少?
1. Softmax函数的广泛应用
多分类输出、归一化概率、注意力权重、配分函数等场景都用到Softmax函数。但人们很少真正思考其内部工作原理。
2. Softmax对分布的作用
Softmax函数看似简单,其公式为$\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$ ,它将任意实数向量转换为0到1之间且总和为1的值,是一个伪概率分布。它把向量映射到概率单纯形,在三维空间中像三角形,高维空间原理类似。在语言模型场景中,它能将无界分数转换为总和为1的概率分布,放大值之间的相对差异,但在需要不确定性估计时可能有问题。
3. 数值稳定性与溢出问题
对于较小输入,简单实现方式效果好,但包含指数的Softmax函数存在溢出问题。如输入 `x = [1000, 1001, 1002]` 会导致 `inf / inf = nan` 。Softmax相互关联特性使问题更糟,一个输入溢出会使所有输出变成NaN。解决方案是移动输入值,利用指数函数性质,将输入值减去最大值,即 `c = -max(x)` ,可避免溢出。在NumPy中实现如下:
def stable_softmax(x, axis=-1): # axis=-1 表示最后一个维度 shift_x = x - np.max(x, axis=axis, keepdims=True) exp_shift_x = np.exp(shift_x) return exp_shift_x / np.sum(exp_shift_x, axis=axis, keepdims=True)4. 雅可比矩阵
Softmax是向量函数,需考虑雅可比矩阵。雅可比矩阵是向量函数所有一阶偏导数组成的矩阵。Softmax将所有维度耦合在一起,增加一个输入会增加自身输出并减少其他输出。
4.1 计算雅可比矩阵
Softmax从 `ℝⁿ` 到 `ℝⁿ` ,雅可比矩阵是 `n x n` 矩阵,元素由$J_{ij} = \frac{\partial \mathrm{softmax}(x_i)}{\partial x_j}$ 给出。当 `i = j` 时,$J_{ii} = \frac{\partial \mathrm{s}(x_i)}{\partial x_i} = \mathrm{s}(x_i) \cdot (1 - \mathrm{s}(x_i))$ ;当 `i != j` 时,$J_{ij} = \frac{\partial \mathrm{s}(x_i)}{\partial x_j} = -\mathrm{s}(x_i) \cdot \mathrm{s}(x_j)$ 。整个雅可比矩阵可写成$J_{ij} = \mathrm{diag}(s) - s \cdot s^T$ 。
4.2 结构:对角矩阵加秩为1的矩阵
雅可比矩阵是对角矩阵加秩为1的修正项,这使反向传播能高效计算。
4.3 大小的重要性
雅可比矩阵形状为 `n x n` ,在典型Transformer模型中可能非常大,需避免完全实例化。
5. 反向传播
在神经网络反向传播中,根据链式法则$\frac{dL}{dx} = J^T \cdot \frac{dL}{ds}$ 。可将其展开为$\frac{dL}{dx_i} = s_i\left(\frac{dL}{ds_i} - \sum_{j} s_j \cdot \frac{dL}{ds_j}\right)$ ,代码实现如下:
def softmax_backward(dL_ds, s): # s 是前向传播中缓存的Softmax输出 dot = np.sum(dL_ds * s) return s * (dL_ds - dot)这样无需实例化雅可比矩阵,避免大量存储。
6. 与交叉熵损失的联系
Softmax函数常与交叉熵损失一起使用,二者组合为反向传播提供简单表达式$\frac{dL}{dx} = s - y$ ,其中 `s` 是Softmax输出,`y` 是one - hot编码的真实标签。
7. 批量维度 - 轴的重要性
实际中常处理批量数据,`axis` 参数很重要。如形状为 `(batch_size, n_classes)` 的预测批量,在NumPy中用 `axis = 1` 对每个样本的类别独立应用Softmax。`keepdims = True` 确保广播正确。在注意力机制中,可能会用到 `axis = 2` 或 `axis = 3` 。
7.1 温度缩放
Softmax函数会集中概率分布,可通过引入温度参数控制集中程度。
