MQA:全部 Query 共享一套 Key-Value
本文基于昇腾CANN和昇腾NPU,围绕 ops-transformer 仓库的相关技术展开。
MQA(Multi-Query Attention)走到 GQA 的极端——所有 Query Head 共享同一组 K、V。8 个 Head 还是 32 个 Head,都只存一份。这对 KV Cache 的压力最小,代价是 Attention 表达能力下降。但推理任务里,这个 trade-off 往往划算。
MQA 的 KV Cache 省了多少
# MQA——一个 KV Head,全部 Query 复用defmqa_vs_mha_kv_model():""" 看不同模型尺寸的 KV Cache 差异 """configs={"llama-7b":{"layers":32,"heads":32,"dim":4096},"llama-13b":{"layers":40,"heads":40,"dim":5120},"llama-70b":{"layers":80,"heads":64,"dim":8192},}forname,cfginconfigs.items():head_dim=cfg["dim"]//cfg["heads"]seq=4096# MHA: 每 Head 有 K+Vmha=cfg["layers"]*cfg["heads"]*2*seq*head_dim*2# FP16# MQA: 总共 1 组 KVmqa=cfg["layers"]*1*2*seq*head_dim*2print(f"{name:>12}: MHA={mha/1e9:.1f}GB → MQA={mqa/1e9:.1f}GB"f" (省{mha/mqa:.0f}x)")| 模型 | MHA KV Cache | MQA KV Cache | 省 |
|---|---|---|---|
| LLaMA-7B | 3.2GB | 0.1GB | 32x |
| LLaMA-13B | 5.0GB | 0.1GB | 40x |
| LLaMA-70B | 20.0GB | 0.3GB | 64x |
70B 模型的显存省了 64 倍——从 20GB 降到 0.3GB。省出来的空间给更大的 Batch 或更长的 Context。
MQA 的计算流程
# MQA Attention——所有 Q 查同一份 K、Vimporttorchimporttorch.nn.functionalasFclassMQAAttention(torch.nn.Module):def__init__(self,hidden_dim,num_heads):super().__init__()self.num_heads=num_heads# 32self.head_dim=hidden_dim//num_heads# 128# Q 投影:跟 MHA 一样大self.q_proj=torch.nn.Linear(hidden_dim,num_heads*self.head_dim)# K、V 投影:只有 1 组self.k_proj=torch.nn.Linear(hidden_dim,self.head_dim)# 1 组!self.v_proj=torch.nn.Linear(hidden_dim,self.head_dim)# 1 组!defforward(self,x,past_kv=None):B,S,H=x.shape# Q 展开成 32 个 Headq=self.q_proj(x).reshape(B,S,self.num_heads,self.head_dim)# K、V 只有 1 组——shape: [B, S, 1, head_dim]k=self.k_proj(x).unsqueeze(2)# [B, S, 1, 128]v=self.v_proj(x).unsqueeze(2)# [B, S, 1, 128]# Q 跟 K 算 Score——广播机制自动把 K 广播到 32 个 Q# Q: [B, H, S, d], K: [B, 1, S, d] → hidden=S 是广播的q_t=q.transpose(1,2)# [B, 32, S, 128]k_t=k.transpose(1,2)# [B, 1, S, 128]# 广播 MatMul:32 个 Q 各自跟同一份 K 算 Scorescore=torch.matmul(q_t,k_t.transpose(-2,-1))# [B, 32, S, S]score=score/(self.head_dim**0.5)# 屏蔽 + Softmaxmask=torch.triu(torch.ones(S,S),diagonal=1).bool()score.masked_fill_(mask,float("-inf"))attn=F.softmax(score,dim=-1)# Attention 输出out=torch.matmul(attn,v.transpose(1,2))# V 也是广播的returnout.transpose(1,2).reshape(B,S,-1)广播 MatMul 是 PyTorch 层面自动做的,但在 NPU 上不能依赖自动广播——要手动安排 K、V 的 L1 复用。
CANN 上 MQA 的显存优化
// Ascend C 实现 MQA——K、V 只搬一次到 L1,32 个 Q 轮流算classMQAKernel:publicAscendC::Kernel{__aicore__inlinevoidProcess()override{// 只有 1 组 K、V——这是跟 GQA 唯一不同的地方constintnum_q_heads=32;constintnum_kv_heads=1;// MQA 的硬编码constintgroup_size=32;// 不是 4 了// K 和 V 只需加载 1 次AscendC::LocalTensor<float>k_local;AscendC::LocalAlloc(k_local,seq_len*head_dim);AscendC::DataCopy(k_local,gm_k,seq_len*head_dim);AscendC::LocalTensor<float>v_local;AscendC::LocalAlloc(v_local,seq_len*head_dim);AscendC::DataCopy(v_local,gm_v,seq_len*head_dim);// 32 个 Q 依次算——K、V 已在 L1,不需要重搬for(inth=0;h<num_q_heads;h++){AscendC::LocalTensor<float>q_local;AscendC::LocalAlloc(q_local,seq_len*head_dim);AscendC::DataCopy(q_local,gm_q+h*seq_len*head_dim,seq_len*head_dim);// Q @ K^T——K 已经在 L1 了AscendC::LocalTensor<float>score_local;AscendC::LocalAlloc(score_local,seq_len*seq_len);AscendC::MatMul(score_local,q_local,k_local,AscendC::CUBE_MATRIX_TYPE::TRANS_B);// Score @ V——V 也已经在 L1 了AscendC::LocalTensor<float>out_local;AscendC::LocalAlloc(out_local,seq_len*head_dim);AscendC::MatMul(out_local,score_local,v_local);// 写回——之前这段显存全给 KV Cache 了AscendC::DataCopy(gm_out+h*seq_len*head_dim,out_local,seq_len*head_dim);}// K、V 的 L1 空间在函数退出时自动释放// 64 个 Head 的搬运成本只付 1 次}};MQA 的设计哲学是:K、V 的多样性没那么重要。LLM 的 Self-Attention 里,Query 决定关注哪里,Key-Value 只提供上下文。多个 Head 共享 K、V 后精度损失远小于 KV Cache 减半的收益。实测 MQA 版 Llama 在推理时吞吐是 MHA 的 2.8 倍,精度差在 0.2% 以内。
参考仓库
MQA 等 Attention 变种算子
Transformer 加速库 ATB
