LLM 推理加速:从算子融合到投机解码的工程实践
LLM 推理加速:从算子融合到投机解码的工程实践
一、延迟瓶颈:内存带宽而非算力
大模型推理的延迟主要卡在四个环节:数据搬运(权重从 HBM 加载)、计算(矩阵乘和注意力)、KV Cache 管理(历史 Token 读写)以及调度开销(请求排队)。实际部署中,真正的瓶颈往往是内存带宽,而非计算能力。
以 A100-80G 为例,其 FP16 峰值算力达 312 TFLOPS,但 HBM 带宽仅为 2TB/s。一个 7B 模型单次前向传播,计算量约 14GFLOP(耗时 0.045ms),但读取权重需 14GB(耗时约 7ms)。数据搬运耗时是计算的 155 倍。这就是典型的“内存墙”:推理性能受限于带宽,算力大部分时间在空转等待数据。
加速的核心思路很直接:减少内存访问(算子融合、KV Cache 优化)、提高计算密度(连续批处理、投机解码)、降低精度(量化)。具体选哪种,得看业务对延迟和精度的容忍度。
二、技术栈分层
flowchart TB subgraph 模型层优化 Q[模型量化: FP16→INT8/INT4] --> Q1[显存减少50-75%] Q --> Q2[带宽需求降低] GQA[GQA/MQA: 共享KV头] --> GQA1[KV Cache减少4-8x] end subgraph 算子层优化 FUSE[算子融合: Flash Attention] --> FUSE1[减少HBM访问次数] FUSE --> FUSE2[单次前向: 7ms→2ms] KV[KV Cache分页: PagedAttention] --> KV1[显存利用率95%+] end subgraph 调度层优化 CB[连续批处理: Continuous Batching] --> CB1[吞吐量提升2-3x] SD[投机解码: Speculative Decoding] --> SD1[延迟降低30-50%] PD[前缀缓存: Prefix Caching] --> PD1[重复Prompt零计算] end subgraph 系统层优化 CB1 --> THROUGHPUT[吞吐量优化] SD1 --> LATENCY[延迟优化] Q2 --> COST[成本优化] end style FUSE fill:#e3f2fd style CB fill:#fff3e0 style SD fill:#e8f5e9 style Q fill:#fce4ec优化通常按模型、算子、调度、系统四个层面展开。模型层解决显存和带宽(量化、GQA);算子层解决计算效率(Flash Attention、PagedAttention);调度层解决并发(连续批处理、投机解码);系统层解决资源复用(前缀缓存)。各层优化可独立生效,组合使用效果更明显。
三、核心工程实现
3.1 连续批处理(Continuous Batching)
传统静态批处理必须等所有请求生成完毕才能释放显存,而连续批处理在每个迭代步动态调整批次:完成的请求立即移出,新请求立即加入。
# continuous_batching.py — 连续批处理调度器 import time from dataclasses import dataclass, field from typing import Optional from collections import deque @dataclass class InferenceRequest: """推理请求""" request_id: str prompt_tokens: list[int] max_output_tokens: int = 256 temperature: float = 0.7 # 运行时状态 generated_tokens: list[int] = field(default_factory=list) is_completed: bool = False arrival_time: float = field(default_factory=time.time) first_token_time: Optional[float] = None class ContinuousBatcher: """连续批处理调度器""" def __init__(self, max_batch_size: int = 32, max_waiting_queue: int = 1000, scheduling_policy: str = "fcfs"): self._max_batch_size = max_batch_size self._max_waiting_queue = max_waiting_queue self._scheduling_policy = scheduling_policy self._waiting_queue: deque[InferenceRequest] = deque() self._running_batch: list[InferenceRequest] = [] self._completed_requests: list[InferenceRequest] = [] def submit(self, request: InferenceRequest) -> bool: """提交推理请求""" if len(self._waiting_queue) >= self._max_waiting_queue: return False self._waiting_queue.append(request) return True def step(self, model_step_fn) -> list[InferenceRequest]: """执行一个推理步骤""" # 1. 移除已完成的请求 completed = [req for req in self._running_batch if req.is_completed] self._running_batch = [req for req in self._running_batch if not req.is_completed] self._completed_requests.extend(completed) # 2. 补充新请求到批次 available_slots = self._max_batch_size - len(self._running_batch) while available_slots > 0 and self._waiting_queue: if self._scheduling_policy == "fcfs": request = self._waiting_queue.popleft() elif self._scheduling_policy == "sjf": shortest = min(self._waiting_queue, key=lambda r: r.max_output_tokens) self._waiting_queue.remove(shortest) request = shortest else: request = self._waiting_queue.popleft() self._running_batch.append(request) available_slots -= 1 # 3. 执行前向传播 if self._running_batch: model_step_fn(self._running_batch) for req in self._running_batch: if req.first_token_time is None: req.first_token_time = time.time() if len(req.generated_tokens) >= req.max_output_tokens: req.is_completed = True return completed def get_stats(self) -> dict: """获取调度器统计信息""" return { "waiting_queue_size": len(self._waiting_queue), "running_batch_size": len(self._running_batch), "completed_count": len(self._completed_requests), "utilization": round(len(self._running_batch) / self._max_batch_size, 2) if self._max_batch_size > 0 else 0, }3.2 投机解码(Speculative Decoding)
用小模型(Draft Model)快速生成 K 个候选 Token,大模型(Target Model)一次性验证。只有被大模型接受的 Token 才计入最终结果。
# speculative_decoding.py — 投机解码实现 import time from dataclasses import dataclass from typing import Optional @dataclass class SpeculativeConfig: """投机解码配置""" draft_model_name: str = "qwen2-0.5b" target_model_name: str = "qwen2-7b" speculative_length: int = 5 temperature: float = 0.7 class SpeculativeDecoder: """投机解码器 加速比 = 1 / (1 - 接受率) 当接受率为 80% 时,理论加速比约 2.5x """ def __init__(self, draft_model_fn=None, target_model_fn=None, config: SpeculativeConfig = None): self._draft_fn = draft_model_fn self._target_fn = target_model_fn self._config = config or SpeculativeConfig() self._accept_stats = { "total_tokens": 0, "accepted_tokens": 0, } def generate(self, prompt_tokens: list[int], max_tokens: int = 256) -> dict: """执行投机解码生成""" generated = [] total_draft_tokens = 0 total_accepted = 0 total_target_calls = 0 while len(generated) < max_tokens: # Step 1: 草稿模型快速生成 K 个候选 Token draft_tokens = self._draft_generate( prompt_tokens + generated, self._config.speculative_length, ) total_draft_tokens += len(draft_tokens) # Step 2: 目标模型一次性验证 K+1 个位置 verify_result = self._target_verify( prompt_tokens + generated, draft_tokens, ) total_target_calls += 1 # Step 3: 处理验证结果 accepted_count = verify_result["accepted_count"] total_accepted += accepted_count generated.extend(draft_tokens[:accepted_count]) # 从拒绝点采样或补充 bonus token if accepted_count < len(draft_tokens): corrected_token = verify_result.get("corrected_token") if corrected_token is not None: generated.append(corrected_token) else: bonus_token = verify_result.get("bonus_token") if bonus_token is not None: generated.append(bonus_token) generated = generated[:max_tokens] self._accept_stats["total_tokens"] += total_draft_tokens self._accept_stats["accepted_tokens"] += total_accepted accept_rate = (total_accepted / total_draft_tokens if total_draft_tokens > 0 else 0) return { "generated_tokens": len(generated), "total_draft_tokens": total_draft_tokens, "accepted_tokens": total_accepted, "accept_rate": round(accept_rate, 4), "target_model_calls": total_target_calls, "speedup_estimate": round(1 / (1 - accept_rate + 0.1), 2), } def _draft_generate(self, context: list[int], num_tokens: int) -> list[int]: """草稿模型生成候选 Token""" if self._draft_fn: return self._draft_fn(context, num_tokens) return list(range(100, 100 + num_tokens)) def _target_verify(self, context: list[int], draft_tokens: list[int]) -> dict: """目标模型验证候选 Token""" if self._target_fn: return self._target_fn(context, draft_tokens) import random accepted = 0 for i in range(len(draft_tokens)): if random.random() < 0.8: accepted += 1 else: break return { "accepted_count": accepted, "corrected_token": 200 if accepted < len(draft_tokens) else None, "bonus_token": 300 if accepted == len(draft_tokens) else None, } def get_accept_rate(self) -> float: """获取历史平均接受率""" total = self._accept_stats["total_tokens"] accepted = self._accept_stats["accepted_tokens"] return round(accepted / total, 4) if total > 0 else 0四、精度代价与适用边界
量化:INT8 对 7B 模型精度影响通常在 0.5% 以内,INT4 则在 1%-3%。对话生成等场景对 INT4 容忍度较高;代码生成、数学推理等强逻辑任务,建议保留 INT8 或 FP8。
投机解码:加速效果完全取决于草稿模型的接受率。如果接受率低于 60%,验证开销会抵消生成收益,反而变慢。同系列模型(如 Qwen2-0.5B 配 Qwen2-7B)输出分布接近,接受率通常在 75%-85%,效果最稳。
连续批处理:吞吐量上去了,但尾部延迟可能增加。短请求若和长请求混批,得等长请求跑完才能释放显存。解决办法是引入优先级调度,或者按延迟要求分批次处理。
前缀缓存:缓存系统提示词等重复 Prompt 的 KV Cache 能省计算,但会占显存。如果命中率低,反而浪费资源。建议只缓存高频前缀,并配上 LRU 淘汰策略。
五、总结
LLM 推理加速是全栈工程,模型、算子、调度、系统四层都有优化空间。从投入产出比看,Flash Attention 和连续批处理最值得优先落地。投机解码在“大小模型搭配”场景下效果明显,但得先测接受率。量化是降低成本的直接手段,INT8 风险低,INT4 需评估业务容忍度。
建议从 Flash Attention + 连续批处理入手,结合 pprof 数据决定是否引入投机解码和量化。每次优化后务必做基准测试,用数据说话。
