CANN-昇腾NPU-Speculative-Decoding-昇腾NPU上怎么用小模型加速大模型推理
Speculative Decoding 用小模型快速生成候选 token,大模型并行验证,正确的保留、错误的重新生成。在昇腾NPU上这个方法有天然优势——NPU 的 batch GEMM 对验证阶段的多 token 并行计算很高效。
原理
1. Draft Model(小模型)自回归生成 K 个候选 token 2. Target Model(大模型)一次 forward 验证 K 个 token 3. 找到第一个错误的 token,保留之前正确的 4. 从错误位置重新开始 例子(K=4): Draft 生成:A B C D Target 验证:A ✓ B ✓ C ✗ D ✗ 接受 A B,从 C 开始重新生成关键:Target Model 的验证是并行的——一次 forward 处理 K 个 token,比自回归快 K 倍。但只有正确率够高(> 60%),总体才比自回归快。
昇腾NPU上的实现
fromatbimportLLM,SpeculativeConfig# Draft Model: Llama2-7Bdraft_model=LLM("meta-llama/Llama-2-7b-hf",device="npu:0")# Target Model: Llama2-70B, 8 卡 TPtarget_model=LLM("meta-llama/Llama-2-70b-hf",device="npu:0,1,2,3,4,5,6,7",tensor_parallel_size=8,speculative_config=SpeculativeConfig(draft_model=draft_model,num_speculative_tokens=4,# 每次猜 4 个 token))output=target_model.generate("Hello",max_new_tokens=100)ATB 内部自动编排 draft 和 target 的交替执行。
为什么昇腾NPU适合 Speculative Decoding
Target Model 验证 K 个 token 时,等效 batch=K 的 prefill。Atlas 800I A2 上 batch=4 的 GEMM 利用率约 25%,而 batch=1(decode)只有 7%。
自回归:每步 batch=1,GEMM 利用率 7% Speculative:每步 batch=4,GEMM 利用率 25% 验证速度提升 25%/7% ≈ 3.5×NPU 在大 batch 下更高效,Speculative Decoding 正好把单 token decode 变成了多 token prefill。
接受率和加速比
加速比取决于 draft model 的接受率。接受率 = draft 生成正确 token 的比例。
| Draft 接受率 | K=4 加速比 | K=8 加速比 |
|---|---|---|
| 90% | 2.8× | 4.2× |
| 80% | 2.2× | 3.0× |
| 70% | 1.7× | 2.1× |
| 60% | 1.3× | 1.4× |
接受率低于 60% 时加速不明显,draft 的开销开始抵消收益。
如何提高接受率
方法 1:用同架构的小模型。Llama2-7B 做 Llama2-70B 的 draft model 比用不同架构的小模型接受率高 10-15%。因为同架构模型的输出分布更接近。
方法 2:增加 Draft Model 的温度。Draft Model 用略高的 Temperature(比如 1.1)生成,让候选更多样化,覆盖 Target Model 可能选择的 token。
方法 3:动态 K 值。不固定 K=4,根据最近几步的接受率动态调整。接受率高时增大 K,低时减小。
显存开销
Draft Model 的权重也要放在 NPU 显存里。Llama2-7B 作为 draft model 需要额外 14GB。
8 卡 Atlas 800I A2 × 64GB = 512GB 总显存:
- Target Model(70B):140GB
- Draft Model(7B):14GB
- KV Cache + buffer:剩余空间
512 - 140 - 14 = 358GB 给 KV Cache。如果不做 Speculative Decoding,504GB 给 KV Cache。显存少了 29%,但吞吐可能提升 2-3×。
Speculative Decoding 在昇腾NPU上的收益特别明显——把低利用率的 decode 变成高利用率的 batch prefill。前提是 draft model 的接受率 > 70%。同架构小模型 + 动态 K 值是最佳实践。仓库在这里:
https://atomgit.com/cann/ATB
