Select to Think:蒸馏 token 排序能力,效果平均提升24%
一句话总结
SLM 的效果差不是因为它不会,而是它没把正确答案排到第一;蒸馏排序能力让 SLM 效果平均提升 24%
- 论文标题:Select to Think: Unlocking SLM Potential with Local Sufficiency
- 论文地址:http://arxiv.org/pdf/2604.26940
- 作者背景:慕尼黑工业大学、慕尼黑大学、华为慕尼黑研究中心
一、动机:业务想用小模型,但效果总差一口气
凡是把大模型搬到线上服务过的同学,都遇到过同一个夹生的局面:
- 用大模型,QPS 一上来预算直接爆炸
- 换上 0.5B / 1.5B 这一档的小模型(SLM),便宜是便宜了、吞吐也起来了,但任务一难,质量肉眼可见地往下掉
举个直观例子:同一个 Qwen2.5-1.5B-Instruct,在 GSM8K 上能拿 72%,到了高中竞赛级的 MATH-500 就只剩 54%,再到 AIME 竞赛压轴题,几乎是 1.1% 的归零成绩。线上业务的痛苦是——只想用小模型省钱、又不想用户体验跌穿地板
围绕这个夹生地带,业界主流的补救方案包含两类
知识蒸馏
最直觉的思路是,让大模型当老师,构造大量样本去训练小模型。但这条路线未必能解决所有业务问题,因为蒸馏小模型的效果一般都会下降,如果是面对稍复杂一些的场景,可能再怎么蒸馏也达不到可用水准。本质问题是 SLM 的参数容量不够:让一个 1.5B 模型把十几万词的概率分布都拟合得跟 32B 老师一模一样,等于让小学生背完整本大学教材
协同推理
实际业务中,我们常常采用更折衷的协同推理方法:让 SLM 和 LLM 相互配合,LLM 只在 SLM 搞不定时才出手点拨两下(代表工作有 R2R、SpecReason、RelayLLM 等)
RelayLLM:token 级大小模型接力加速推理
这一思路看起来很完美:99% 的工作还是小模型扛,大模型只在 1% 的关键时刻露面,应该兼顾成本和效果,但实际部署时会撞上一个被严重低估的开销:LLM 每出手一次,并不只是算下一个 token 那么便宜,它必须把当前已经生成的整段上下文,从头到尾重新做一遍 prefill,然后才能给出建议
也就是说,表面上看像是"问一句答一句",实际上每次"点播"都要把前面几千 token 重读一遍——prefill 的算力开销才是大头。再加上每次同步调用引入的网络延迟和大模型常驻待命的成本,账算下来根本不划算
这两条路虽然实现路径不同,底层都默认了一个共同假设:小模型自己不会,所以一个想把整套答案灌给它,一个想在它需要时去外面找答案
S2T 这篇工作最打动人的地方在于,它跳出了上述思维定势,深入分析了小模型的输出分布,发现上述假设未必成立
二、关键观察:Local Sufficiency
作者做了很朴素的实验:把小模型推理解码时的 top-K 候选全列出来,然后看大模型最想选的那个 token,到底在不在其中。结果出乎意料:
| SLM 规模 | top-1 命中率 | top-8 命中率 |
|---|---|---|
| Qwen2.5-1.5B | 33% | 95% |
| Qwen2.5-0.5B | 较低 | 83% |
换 Gemma 系列、换不同任务,结论都一致
也就是说,小模型并不是不知道答案,而是没能区分正确答案与其他候选,greedy decoding 选 top-1 时恰好挑了个不那么好的,但正确答案就静静躺在它自己的 top-2 ~ top-8 里
作者把这个现象命名为Local Sufficiency(局部充分性):在关键决策点上,SLM 自己提的局部候选集,已经足够覆盖 LLM 的偏好
这一观察直接改变了游戏规则 —— 问题从 “开放生成” 退化成了 “K 选 1 的判别”。原本要让小模型从十几万个词里凭空造出正确答案,现在只要它能在自己提的 8 个候选里挑对就行,任务难度直接降低了一个量级
三、S2T 触发器
不是每个 token 都需要折腾。原理上用 KL 散度衡量 SLM 与 LLM 输出分布的分歧,超过 top-τ 阈值(默认 1%)就触发
但这里有个关键的工程细节:算 KL 需要同时知道 SLM 和 LLM 的分布,部署时不可能每步都真的去叫一次 LLM。作者的做法是:
- 训练阶段:用 LLM 一次性产出 KL 标签,作为 oracle 信号
- 部署阶段:把这个判断蒸馏成一个超轻量的 router head(两层 MLP),小模型只需要看一眼自己的隐藏状态就能判断这一步要不要触发,不引入任何额外的大模型调用
这个 router 在验证集上准确率 87.78%,是把"协同推理"真正做成"自闭环"的关键一环
四、S2T 评分器
触发后,让 SLM 先吐出 top-K 候选(默认 K=8),用一个评分函数给每个候选打分,选分数最高的作为最终输出
4.1 常规解法的问题
作者首先尝试了两种朴素的思路:
- point wise 评分
要实现一个打分组件,最直觉的做法就是让模型直接去预测每个候选的好坏,即使用线性头 + sigmoid 的 point wise 打分排序,但这一路线的问题在于:
- 类别极度不平衡:K=16 个候选里只有 1 个对,正负比 1:15。模型走捷径直接学成"全部预测为错"也能拿 93% 准确率
- 目标冲突:“生成下一个词” 和 “判断这个词好坏” 两个目标是冲突的,因为 “生成” 需要平滑分布(保留多样性)、“判别” 需要尖锐分布(决策果断),两个目标互相拉扯,结果都做不好
- 教师分布蒸馏
既然二分类判别不可取,另一种直观做法是从 LLM 教师分布中,挑出 SLM 的 top-16 token 的分布进行蒸馏学习。如此一来不仅避免了类别不平衡问题,还维持了 “下一个词分布” 的学习目标
但实测下来依然不可取,原因在于一开始就提到的 SLM 参数容量问题:让 SLM 模仿教师的输出 token 已经很难了,让它学习 token 分布岂不是更强人所难
4.2 通道隔离方案
作者借鉴了ZIP(Zero-overhead Inference-time Prediction)的思路:让模型在不动主任务的前提下,用边角料 logits 承担一个独立的子任务
LLM 的词表里本来就有一批保留 token,比如 Qwen2.5 中是 ID 151920–151935 这 16 个,平时被屏蔽,不参与生成。但模型在前向计算时,这些保留 token 的 logits 仍然会被算出来,只是被丢弃了
S2T-Local 的做法是把这部分 logits 重新解释为"打分专用通道":
- 给 16 个保留 token 预定义一组 bin 值(比如 0, 1/15, 2/15, …, 1)
- 这 16 个 logit 经 softmax 得到一个分布
- 用这个分布对 bin 值做加权和,得到一个 0~1 的评分
每个候选拼到当前前缀后做一次前向,取保留 token 的 logits 算分。K 个候选可以批处理一次性算完,所以选择本身的额外开销很小
这个设计的妙处在于评分通道与生成通道彻底解耦:保留 token 学打分,标准 token 继续做语言建模,互不干扰
4.3 训练数据
Selector 的训练集只有 2000 条,全部来自 MATH 训练集,构造流程很轻量:
- 让 SLM 自由地 roll out,在每一步算一下 SLM 与 LLM 输出分布的 KL 散度
- 只在 KL 最高的 top-10% 步骤里采样(这些是"分歧大、值得介入"的关键步骤)
- 在每个被选中的步骤上让 SLM 提 K=16 个候选,利用 LLM 生成每个候选的条件概率
- LLM 概率最高的那个候选标为 Golden label,这就是学生要学着挑出来的
除了 16 个保留 token 的 lm_head 嵌入以外,初始模型参数都被冻结,然后添加 rank=16 的 LoRA 模块。1.5B 模型上可训练参数 74.3M,只占不到6%;用 AdamW 优化器和 cosine 学习率,2 个 epoch 就收敛
4.4 训练目标
最终的损失函数是三项之和:L=L_select + L_margin + β·L_kl,分别代表:
- 选对
把 K 个候选的得分过 softmax(带温度 T=0.2,让分布更尖锐),鼓励 Golden label 的得分更高- L_select = 金标的 softmax 分数
- 选得果断
强制 Golden label 和最强的错误候选之间至少拉开一个边界 m,m 正比于 batch 内所有候选得分的标准差,意味着当教师也无法显著区分不同候选时(标准差小),这里的边界要求也小,避免不合理的惩罚- L_margin = ReLU ( m - ( 金标分数 - 最强错误候选分数 ) )
- 保持语言能力
由于 LoRA 模块加在了所有 attention/MLP 模块上,优化前两个 loss 容易污染标准词表的 logit 输出。所以还需要保证模型不要偏离原始 SLM 太多,避免语言能力被损坏- L_kl = KL ( 原始 SLM 分布 || 训练后 SLM 分布 )
五、实验结果
5.1 六个 benchmark 全面提升
在 0.5B / 1.5B 两个 SLM 规模、六个 benchmark 上的结果:
- 1.5B 平均提升 24.1%(相对 SLM Greedy)
- 0.5B 平均提升 38%,模型越小,从这套方法里捞到的好处反而越大
- 单条轨迹推理的 S2T-Local 追平了 8 路自一致采样(Maj@8),后者要付 8 倍的算力
值得强调的是:S2T-Local 只用 MATH 训练集训练,HumanEval / MMLU-Pro 等都是 OOD(域外)评测,泛化能力很扎实
5.2 Local Sufficiency 验证
为了验证本文提出的 “Local Sufficiency” 假设的正确性,作者进行了 top-k 命中率测试与任务准确性测试。前者表示 SLM 的输出分布中,不同 top-k 下命中正确 token 的比例;后者表示通过上述触发器与选择器进行解码后,最终任务上的准确率情况
| K | Hit Rate | 准确率 |
|---|---|---|
| 1(即 greedy) | 33% | 54.1% |
| 8 | 97% | 81.1% |
| 16 | 99% | 80.0% |
| ∞ | 99% | ~80% |
K=8 是性价比甜点,再扩大候选集对正确率几乎无提升
5.3 推理效率
虽然每次触发要做 K 路前向打分,但因为触发只发生在 1% 的 token 上,整体延迟基本和原始 SLM Greedy 持平。AIME25 上的实测:
| 方法 | 推理时间 | 相对 Greedy |
|---|---|---|
| SLM Greedy | 9.3s | 1× |
| S2T-Local | 21.0s | 2.3× |
| LLM-Takeover(让 LLM 在触发点直接生成) | 81.5s | 8.8× |
S2T-Local 相对 LLM-Takeover节省约 75% 时间,同时把 accuracy 拉了上去。这正是它解决"协同推理 prefill 大头"问题的关键:所有 K 路评分都在 SLM 上完成,可以批处理一次性算完,根本不用同步调用大模型
六、Selector 学到了什么
蒸馏出来的内在评委到底学到了什么?作者做了几个分析
6.1 Selector 准确性
用 Agree@1 来衡量,即看 selector 与 LLM 的 top1 选择是否相同
| 数据集 | Agree@1 | 随机基线 |
|---|---|---|
| MATH500 | 69.4% | 12.5% |
| OlympiadBench | 71.4% | 12.5% |
| MMLU-Pro | 67.2% | 12.5% |
随机猜(K=8)只有 12.5%,Selector 能稳定拿到 ~70%,意味着它确实把 LLM 的偏好规律学到了
除了看排序效果,作者确认了 selector 具体打分的准确性:那些被 SLM 严重低估的 token(SLM prob ≈ 25-30%,但 LLM prob ≈ 60-70%),Selector 会把它们的得分主动拉到 80-90%。这说明 Selector 不是简单复读 SLM 的偏好,而是主动校正了 SLM 的判断失误
6.3 Winner-takes-all:决策边界很锐利
小提琴图显示了一个清晰的赢者通吃特性:
- Rank 1 候选的中位分接近 0.9
- Rank 2 直接掉到接近 0
决策边界非常锐利,绝不拖泥带水。这种锐利的判别边界正好对应了"让小模型决断更果敢"的目标——这也是用 top-1 离散交叉熵训练(而不是分布匹配)能拿到的好处
七、核心启发
- 蒸馏目标的视角切换
S2T 最有价值的地方不是又刷了一波榜,而是它跳出了思维定势对蒸馏目标做了重新定义
| 范式 | 蒸馏的目标 | 信号复杂度 |
|---|---|---|
| 传统词表蒸馏 | 词表级分布模拟 | 极高(十几万维) |
| S2T | 对已有候选打分 | 低(K=8 维) |
这个视角同时也提示我们一个更具普遍意义的洞察:
很多时候 SLM 的瓶颈不是"知识储备不够",而是"决策校准没做好"——它知道答案在 top-8 里某一个位置,只是没能力解开与其他候选的纠缠
如果这是普遍现象,那么"在解码时配一个轻量级的内在评委"可能比"扩参数 / 蒸馏完整分布"性价比高得多,这对追求"低成本高吞吐"的业务部署有非常直接的指导价值
- 对保留 token 的利用
ZIP 这种借保留 token 的 logits 做副任务的范式,可以在不破坏主任务的前提下让模型多承担一个子任务,未来也许能扩展到更多场景:
- 置信度估计
- 毒性 / 安全性检测
- 情感预测
- 结构化输出的 schema 校验
只要任务能编码成 0~1 的分值或一组离散 bin,几乎不增加推理开销就能塞进这套机制
