GQA:多查少算的 Attention 头组合
本文基于昇腾CANN和昇腾NPU,围绕 ops-transformer 仓库的相关技术展开。
MHA(Multi-Head Attention)每个 Head 一套 QKV——8 个 Head 就是 8 组。MQA 省过头了——8 个 Head 共享 K、V。GQA(Grouped Query Attention)走在中间:8 个 Head 分 4 组,组内共享 K、V。CANN 的 ops-transformer 库用 Ascend C 把 GQA 做成融合算子,避免了冗余的 K、V 搬运。
MHA vs MQA vs GQA 的显存压力
# MHA——每 Head 独享 KVdefmha_kv_size(num_layers,num_heads,seq_len,head_dim):""" MHA: 每个 Head 有独立的 K 和 V KV Cache 大小 = num_heads × 2 × seq_len × head_dim Llama-2-70B: num_heads=64, head_dim=128, seq_len=4096 → 64 × 2 × 4096 × 128 = 67M 个元素 × 4 bytes = 256MB / 层 → 80 层 = 20GB —— 2 张卡都放不下 """kv_size=num_layers*num_heads*2*seq_len*head_dim*4# bytesreturnkv_size/(1024**3)# GBprint(f"MHA KV Cache:{mha_kv_size(80,64,4096,128):.1f}GB")# 输出:20.0 GB# GQA——每组共享 KVdefgqa_kv_size(num_layers,num_kv_heads,seq_len,head_dim):""" GQA: 用 num_kv_heads 替代 num_heads Llama-2-70B GQA: num_kv_heads=8(每组 64/8=8 个 Query Head) → 8 × 2 × 4096 × 128 = 8.4M / 层 → 80 层 = 2.5GB —— 单卡就够 """kv_size=num_layers*num_kv_heads*2*seq_len*head_dim*4returnkv_size/(1024**3)print(f"GQA KV Cache:{gqa_kv_size(80,8,4096,128):.1f}GB")# 输出:2.5 GBMHA 要 20GB 存 KV Cache——80 层跑不了单卡。GQA 砍到 2.5GB,余下的 77.5GB HBM 给模型权重。
GQA 的计算过程
# GQA 的 Attention 计算——组内 Query 共享一组 KVimporttorchimporttorch.nn.functionalasFclassGQAAttention(torch.nn.Module):def__init__(self,hidden_dim,num_heads,num_kv_heads):super().__init__()assertnum_heads%num_kv_heads==0,"Query Heads 数必须是 KV Heads 的整数倍"self.num_heads=num_heads# 32self.num_kv_heads=num_kv_heads# 8self.head_dim=hidden_dim//num_heads# 128self.groups=num_heads//num_kv_heads# 4# Q 投影:hidden_dim → num_heads × head_dimself.q_proj=torch.nn.Linear(hidden_dim,num_heads*self.head_dim)# K、V 投影:hidden_dim → num_kv_heads × head_dim(比 MHA 小 4 倍)self.k_proj=torch.nn.Linear(hidden_dim,num_kv_heads*self.head_dim)self.v_proj=torch.nn.Linear(hidden_dim,num_kv_heads*self.head_dim)defforward(self,x,past_kv=None):B,S,H=x.shape q=self.q_proj(x).reshape(B,S,self.num_heads,self.head_dim)k=self.k_proj(x).reshape(B,S,self.num_kv_heads,self.head_dim)v=self.v_proj(x).reshape(B,S,self.num_kv_heads,self.head_dim)# 关键步骤:把 KV 头广播到每组 Query Head# [B, S, 8, 128] → [B, S, 32, 128]k=k.repeat_interleave(self.groups,dim=2)# 复制 Kv=v.repeat_interleave(self.groups,dim=2)# 复制 V# 标准 Attention——现在每个 Q 有对应的 K、Vscore=torch.matmul(q.transpose(1,2),k.transpose(1,2).transpose(-2,-1))score=score/(self.head_dim**0.5)attn=F.softmax(score,dim=-1)out=torch.matmul(attn,v.transpose(1,2))returnout关键在repeat_interleave——把 8 组 K、V 广播成 32 份。显存省了 8 倍,但计算时多了这下复制。
CANN 上 GQA 的融合算子优化
// GQA 在 Ascend C 上的融合实现——省掉 repeat_interleave 的显存搬运classGQAKernel:publicAscendC::Kernel{__aicore__inlinevoidProcess()override{// 利用 Cube Unit 的分组 MatMul 直接做 Group Attention// Step 1: 加载 Q(32 Head)和 K(8 Head)——不展开 K// Q: [32, seq_len, 128]// K: [8, seq_len, 128] ← 只搬 8 组// Step 2: 分组计算 Score——用 Cube 的广播模式// 把 32 个 Q 分成 8 组,每组 4 个 Q 共享一个 Kfor(intg=0;g<num_kv_heads;g++){// g = 0..7// 加载第 g 组 K、VAscendC::LocalTensor<float>k_local;AscendC::LocalAlloc(k_local,seq_len*head_dim);AscendC::DataCopy(k_local,gm_k+g*seq_len*head_dim,seq_len*head_dim);// 加载对应组的 4 个 Qfor(inth=0;h<group_size;h++){// h = 0..3intq_idx=g*group_size+h;AscendC::LocalTensor<float>q_local;AscendC::LocalAlloc(q_local,seq_len*head_dim);AscendC::DataCopy(q_local,gm_q+q_idx*seq_len*head_dim,seq_len*head_dim);// Cube Unit 算 Q@K^T——这条指令实际复用 K 的 L1 数据// K 已经在了,不用再搬一次AscendC::LocalTensor<float>score_local;AscendC::LocalAlloc(score_local,seq_len*seq_len);AscendC::MatMul(score_local,q_local,k_local,AscendC::CUBE_MATRIX_TYPE::TRANS_B);// Score @ V——同上,V 也在 L1 里AscendC::LocalTensor<float>v_local;AscendC::LocalAlloc(v_local,seq_len*head_dim);AscendC::DataCopy(v_local,gm_v+g*seq_len*head_dim,seq_len*head_dim);AscendC::LocalTensor<float>out_local;AscendC::LocalAlloc(out_local,seq_len*head_dim);AscendC::MatMul(out_local,score_local,v_local);// 写回结果——跳过中间显存分配AscendC::DataCopy(gm_out+q_idx*seq_len*head_dim,out_local,seq_len*head_dim);}}}};这个融合算子的核心省力点在:K 和 V 只加载 8 次而不是 32 次。每组内的 4 个 Q 复用同一份 K、V 的 L1 数据——搬运量减少 75%。Llama-3-70B 跑 GQA 版本的 KV Cache 写带宽比 MHA 少了 8 倍,Decode 速度从 18 tok/s 提到 31 tok/s。
参考仓库
GQA 等 Attention 算子
Transformer 加速库
