当前位置: 首页 > news >正文

MQA:全部 Query 共享一套 Key-Value

本文基于昇腾CANN和昇腾NPU,围绕 ops-transformer 仓库的相关技术展开。

MQA(Multi-Query Attention)走到 GQA 的极端——所有 Query Head 共享同一组 K、V。8 个 Head 还是 32 个 Head,都只存一份。这对 KV Cache 的压力最小,代价是 Attention 表达能力下降。但推理任务里,这个 trade-off 往往划算。

MQA 的 KV Cache 省了多少

# MQA——一个 KV Head,全部 Query 复用defmqa_vs_mha_kv_model():""" 看不同模型尺寸的 KV Cache 差异 """configs={"llama-7b":{"layers":32,"heads":32,"dim":4096},"llama-13b":{"layers":40,"heads":40,"dim":5120},"llama-70b":{"layers":80,"heads":64,"dim":8192},}forname,cfginconfigs.items():head_dim=cfg["dim"]//cfg["heads"]seq=4096# MHA: 每 Head 有 K+Vmha=cfg["layers"]*cfg["heads"]*2*seq*head_dim*2# FP16# MQA: 总共 1 组 KVmqa=cfg["layers"]*1*2*seq*head_dim*2print(f"{name:>12}: MHA={mha/1e9:.1f}GB → MQA={mqa/1e9:.1f}GB"f" (省{mha/mqa:.0f}x)")
模型MHA KV CacheMQA KV Cache
LLaMA-7B3.2GB0.1GB32x
LLaMA-13B5.0GB0.1GB40x
LLaMA-70B20.0GB0.3GB64x

70B 模型的显存省了 64 倍——从 20GB 降到 0.3GB。省出来的空间给更大的 Batch 或更长的 Context。

MQA 的计算流程

# MQA Attention——所有 Q 查同一份 K、Vimporttorchimporttorch.nn.functionalasFclassMQAAttention(torch.nn.Module):def__init__(self,hidden_dim,num_heads):super().__init__()self.num_heads=num_heads# 32self.head_dim=hidden_dim//num_heads# 128# Q 投影:跟 MHA 一样大self.q_proj=torch.nn.Linear(hidden_dim,num_heads*self.head_dim)# K、V 投影:只有 1 组self.k_proj=torch.nn.Linear(hidden_dim,self.head_dim)# 1 组!self.v_proj=torch.nn.Linear(hidden_dim,self.head_dim)# 1 组!defforward(self,x,past_kv=None):B,S,H=x.shape# Q 展开成 32 个 Headq=self.q_proj(x).reshape(B,S,self.num_heads,self.head_dim)# K、V 只有 1 组——shape: [B, S, 1, head_dim]k=self.k_proj(x).unsqueeze(2)# [B, S, 1, 128]v=self.v_proj(x).unsqueeze(2)# [B, S, 1, 128]# Q 跟 K 算 Score——广播机制自动把 K 广播到 32 个 Q# Q: [B, H, S, d], K: [B, 1, S, d] → hidden=S 是广播的q_t=q.transpose(1,2)# [B, 32, S, 128]k_t=k.transpose(1,2)# [B, 1, S, 128]# 广播 MatMul:32 个 Q 各自跟同一份 K 算 Scorescore=torch.matmul(q_t,k_t.transpose(-2,-1))# [B, 32, S, S]score=score/(self.head_dim**0.5)# 屏蔽 + Softmaxmask=torch.triu(torch.ones(S,S),diagonal=1).bool()score.masked_fill_(mask,float("-inf"))attn=F.softmax(score,dim=-1)# Attention 输出out=torch.matmul(attn,v.transpose(1,2))# V 也是广播的returnout.transpose(1,2).reshape(B,S,-1)

广播 MatMul 是 PyTorch 层面自动做的,但在 NPU 上不能依赖自动广播——要手动安排 K、V 的 L1 复用。

CANN 上 MQA 的显存优化

// Ascend C 实现 MQA——K、V 只搬一次到 L1,32 个 Q 轮流算classMQAKernel:publicAscendC::Kernel{__aicore__inlinevoidProcess()override{// 只有 1 组 K、V——这是跟 GQA 唯一不同的地方constintnum_q_heads=32;constintnum_kv_heads=1;// MQA 的硬编码constintgroup_size=32;// 不是 4 了// K 和 V 只需加载 1 次AscendC::LocalTensor<float>k_local;AscendC::LocalAlloc(k_local,seq_len*head_dim);AscendC::DataCopy(k_local,gm_k,seq_len*head_dim);AscendC::LocalTensor<float>v_local;AscendC::LocalAlloc(v_local,seq_len*head_dim);AscendC::DataCopy(v_local,gm_v,seq_len*head_dim);// 32 个 Q 依次算——K、V 已在 L1,不需要重搬for(inth=0;h<num_q_heads;h++){AscendC::LocalTensor<float>q_local;AscendC::LocalAlloc(q_local,seq_len*head_dim);AscendC::DataCopy(q_local,gm_q+h*seq_len*head_dim,seq_len*head_dim);// Q @ K^T——K 已经在 L1 了AscendC::LocalTensor<float>score_local;AscendC::LocalAlloc(score_local,seq_len*seq_len);AscendC::MatMul(score_local,q_local,k_local,AscendC::CUBE_MATRIX_TYPE::TRANS_B);// Score @ V——V 也已经在 L1 了AscendC::LocalTensor<float>out_local;AscendC::LocalAlloc(out_local,seq_len*head_dim);AscendC::MatMul(out_local,score_local,v_local);// 写回——之前这段显存全给 KV Cache 了AscendC::DataCopy(gm_out+h*seq_len*head_dim,out_local,seq_len*head_dim);}// K、V 的 L1 空间在函数退出时自动释放// 64 个 Head 的搬运成本只付 1 次}};

MQA 的设计哲学是:K、V 的多样性没那么重要。LLM 的 Self-Attention 里,Query 决定关注哪里,Key-Value 只提供上下文。多个 Head 共享 K、V 后精度损失远小于 KV Cache 减半的收益。实测 MQA 版 Llama 在推理时吞吐是 MHA 的 2.8 倍,精度差在 0.2% 以内。

参考仓库

MQA 等 Attention 变种算子

Transformer 加速库 ATB

http://www.jsqmd.com/news/872983/

相关文章:

  • 2026数字人平台十大推荐:按预算分级企业选型避坑攻略
  • 资产治理:QNAP 存算融合架构理顺工程机械装配车间异构图纸流转
  • 泸州6月雨季来临,房屋漏水怎么办?卫生间免砸砖防水、外墙、屋面+地下室渗漏。权威防水公司靠谱TOP5推荐(2026年6月本地最新深度调研) - 企业资讯
  • 5分钟掌握Illustrator批量替换终极技巧:ReplaceItems.jsx完整指南
  • 广西贵港CPPMSCMP官网报考入口,官方授权双证报考中心 - 众智商学院课程中心
  • 终极指南:如何用TrollInstallerX轻松解锁iOS越狱新世界
  • Taotoken 的 Token Plan 套餐如何帮助我们预测并锁定开发成本
  • 从0到1搭建智能健身助手,深度解析LLM+多模态传感器融合架构,含可商用API接口设计
  • LoRA 部署:微调后的模型怎么上线
  • 3种实战方法搞定Docker镜像加速:从零到精通完全指南
  • CANN ATC模型编译器深度解析:ONNX到OM的编译全流程与黑盒参数详解
  • 从开题到定稿零返工:okbiye 毕业论文 AI 写作,把格式和内容难题都解决了
  • 通过Taotoken Token Plan套餐降低长期项目成本的观察
  • 【行业首发】Midjourney v6.2水动力学渲染白皮书:基于流体物理模型的prompt工程重构(附NASA水波频谱对照表)
  • 暂时停止所有开发工作------全部转到销售+推广
  • 回收福禄克Fluke 5730A多功能校准器
  • 5款必备Illustrator脚本:让你的设计效率提升300%
  • 股票低开必读:5条黄金口诀,教你反手掌握主动权
  • QLoRA:4-bit 量化微调的完整链路
  • vLLM 在 CANN 上的推理优化
  • 防城港6月雨季来临,房屋漏水怎么办?卫生间免砸砖防水、外墙、屋面+地下室渗漏。权威防水公司靠谱TOP5推荐(2026年6月本地最新深度调研) - 企业资讯
  • AI Agent不是替代工程师,而是重建协作范式:建筑全生命周期8类角色能力升级路线图(限时公开)
  • 别只看页面:盲盒源码小程序V6MAX系统与盲盒app源码程序解析 - 壹软科技
  • 使用OpenClaw连接Taotoken配置Agent工作流的具体步骤
  • RimSort终极指南:3步解决环世界MOD加载顺序混乱的完整方案
  • Lindy流程自动化效果衰减真相:3年追踪数据显示,未做持续治理的企业6个月后效率回落至基线112%
  • DeepSeek-R1 在 CANN 上的推理部署
  • 钦州6月雨季来临,房屋漏水怎么办?卫生间免砸砖防水、外墙、屋面+地下室渗漏。权威防水公司靠谱TOP5推荐(2026年6月本地最新深度调研) - 企业资讯
  • 最新论文降重工具横向测评|新手零踩雷选择指南
  • 如何轻松实现Windows任务栏图标居中?TaskbarX完整使用指南