当前位置: 首页 > news >正文

注意力机制:多头注意力机制、分组查询注意力机制、多查询注意力机制理论+代码

文章目录

    • 导语
    • 1.注意力机制
    • 2.多头注意力机制
    • 3.多查询注意力机制
    • 4.分组查询注意力机制
    • 5.三者对比

导语

注意力机制作为transformer体系中最核心的方法,是NLP、LLM等都绕不开的一部分,多头注意力机制是transformer模型提出的“基石”,分组查询注意力机制是LLaMA2、Qwen等主流大模型对传统多头注意力机制的优化,多查询注意力机制是提升推理速度的高效方法。

因此,本文将对基础的注意力机制、多头注意力机制MHA,及其变体分组查询注意力机制GQA、多查询注意力机制MQA的理论与代码进行剖析,旨在记录学习过程并起到深刻理解的作用。

1.注意力机制

真正弄懂一个模型,一定要知道它是什么、为什么提出、怎么用。

为什么要提出注意力机制
注意力机制的作用是让模型有权重的选择某些信息(就好比看一篇长文章,在关键词句上停留的时间一定比一些助词语气词停留的时间要久,并不是每个字都花同等时间去看)。

注意力机制是什么
注意力机制和核心是围绕三个向量展开:q、k、v。
q(查询)向量:我想要什么(需求、问题、目标)。
k(键)向量:我存有什么候选信息(匹配依据)。
v(值)向量:K 对应的真实内容(最终拿来用的信息)。

注意力机制怎么用
其核心公式为:

用q去匹配所有k,算出权重,再加权抽取对应的v,得到最终向量表示。
Q⋅K:计算Q和K的点积,本质是相似度匹配。点积越大,Q和K的关联越强,模型对这个K对应的V的关注度就越高。
√(dk:缩放因子,防止点积数值过大,导致Softmax后梯度消失。至于为什么是除以根号dk,我在之前的文章中有提到过,如感兴趣可以在这篇文章中查看为什么attention要除以根号dk。
Softmax:将相似度归一化成0~1之间的权重,所有权重和为1,把“相似度”转化为“关注度权重”。
V:用权重对V加权求和,得到最终的注意力输出——关联强的信息权重高,主导输出;无关信息权重低,被弱化。
自注意力是注意力机制的特例,指Q、K、V全部来自同一个输入序列,用于挖掘序列内部的关联(比如句子中“它”指代哪个词),后续在描述不同变种注意力时都采用自注意力的形式。

2.多头注意力机制

多头注意力机制与注意力机制的本质区别是:普通注意力机制的QKV矩阵是用一个,多头注意力机制的QKV矩阵是多个。

为什么要用多头
单头注意力机制只能在一个维度上获取语义特征,比如一个人看文章可能会遗漏重要的信息。将QKV通过线性层拆分为若干个头,每个头分别在低维度上计算注意力机制,最后将所有头进行拼接融合,相当于将一篇文章分给多个人去看。并且多头注意力机制与单头注意力机制总计算量相差并不大(多一个不同头融合的操作),但是多头注意力机制的表达能力大幅提高。

实现方法
现有输入序列的维度为batch_size×seq_len×d_model,其中batch_size是批次大小,seq_len是序列长度,d_model是token的嵌入维度,h是注意力头数(需满足d_model能被h整除),具体流程如下:
1.线性投影:将Q、K、V分别通过3个独立的线性层(权重矩阵分别为d_model×d_model,得到投影后的Q、K、V,维度仍为batch_size×seq_len×d_model。
2.拆分多头:将投影后的Q、K、V拆分成h个独立的头,每个头的维度为d_k=d_model/h,维度转换为batch_size×h×seq_len×d_k。

为什么head与seq_len要交换维度?
注意力本质是建模不同注意力头中每个token与其他token之间的语义关系,若head与seq_len不变换维度则变为建模token内head_dim之间的语义关系,丧失原有设计意义。

3.并行计算注意力:对每个头,独立执行缩放点积注意力计算,得到每个头的输出batch_size×h×seq_len×d_k。
4.拼接头输出:将h个头的输出拼接起来,维度还原为batch_size×seq_len×d_model。
5.最终线性融合:通过一个线性层(权重矩阵W_O:d_model×d_model),对拼接后的结果进行融合,得到最终输出batch_size×h×seq_len×d_model。

代码

classMultiHeadAttention(nn.Module):def__init__(self,d_model,num_heads,dropout=0.1):super(MultiHeadAttention,self).__init__()# d_model必须能被num_heads整除,否则每个头的维度不相等assertd_model%num_heads==0,"d_model must be divisible by num_heads"self.d_model=d_model# 总嵌入维度self.num_heads=num_heads# 注意力头数self.d_k=d_model//num_heads# 每个头的维度# 定义Q、K、V的线性投影层(3个独立线性层)self.wq=nn.Linear(d_model,d_model)self.wk=nn.Linear(d_model,d_model)self.wv=nn.Linear(d_model,d_model)# 定义最终的输出线性层self.wo=nn.Linear(d_model,d_model)# 定义dropout层self.dropout=nn.Dropout(dropout)# 定义层归一化(可选,提升训练稳定性)self.norm=nn.LayerNorm(d_model)defscaled_dot_product_attention(self,q,k,v,mask=None):# 1. 计算Q和K的点积(相似度),维度变为(batch_size, num_heads, seq_len_q, seq_len_k)attn_scores=torch.matmul(q,k.transpose(-2,-1))# 2. 缩放:除以sqrt(d_k),防止点积过大导致Softmax梯度消失attn_scores=attn_scores/math.sqrt(self.d_k)# 3. 掩码:将需要屏蔽的位置设为极小值,Softmax后趋近于0ifmaskisnotNone:attn_scores=attn_scores.masked_fill(mask==0,-1e9)# 4. Softmax计算注意力权重,维度不变attn_weights=torch.softmax(attn_scores,dim=-1)# 5. 应用dropoutattn_weights=self.dropout(attn_weights)# 6. 权重加权求和V,得到注意力输出,维度(batch_size, num_heads, seq_len_q, d_k)output=torch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,mask=None):batch_size=q.size(0)# 步骤1:线性投影Q、K、Vq_proj=self.wq(q)# (batch_size, seq_len_q, d_model)k_proj=self.wk(k)# (batch_size, seq_len_k, d_model)v_proj=self.wv(v)# (batch_size, seq_len_v, d_model)# 步骤2:拆分多头(维度转换:(batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k))q_proj=q_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)k_proj=k_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)v_proj=v_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# 步骤3:并行计算缩放点积注意力attn_output,attn_weights=self.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)# 步骤4:拼接头输出(维度转换:(batch_size, num_heads, seq_len, d_k) -> (batch_size, seq_len, d_model))attn_output=attn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)# contiguous()确保张量内存连续,避免view报错# 步骤5:最终线性融合 + dropout + 残差连接(可选,提升训练稳定性)output=self.wo(self.dropout(attn_output))output=self.norm(output+q)# 残差连接:输出 + 原始输入qreturnoutput,attn_weights

3.多查询注意力机制

多头注意力机制可以捕获不同子空间的特征,但是每个头都需要独立的q、k、v线性层投影,并且随着序列长度的增加,kv cache需要存储大量信息,增加了计算开销。
正是因此多头注意力机制的这些缺点,因此衍生出了多查询注意力机制MQA,所有注意力头共享一套K和V的投影权重,只保留每个头独立的Q投影权重。
对比多头注意力的区别:多头注意力中h个头有h组Q、K、V;而多查询注意力中h个头只有h组Q,却只有1组K、V——相当于“多个医生会诊,但所有人共用一套检查报告(K、V)”,大幅减少了参数冗余和显存占用。Q负责“从不同角度查询”,K、V负责“提供候选信息和实际内容”,共享K、V并不会显著影响模型的表达能力(因为Q的多样性已经能覆盖不同的查询角度),但能极大降低KV Cache的开销(只需要缓存1组K、V,而不是h组)。

实现方法
它的实现方法与多头注意力基本一致,只是在线性投影和拆分多头时有差异,具体流程:
1.线性投影:通过h个独立的线性层(或1个大线性层拆分),得到h组Q,维度为batch_size× seq_len×d_model。通过1个线性层,得到1组K,维度batch_size×seq_len×d_k(d_k = d_model/h)。通过1个线性层,得到1组V,维度batch_size×seq_len×d_k。
2. 拆分多头:Q会拆分成h个独立的头,维度batch_size×h×seq_len×d_k。K、V则不需要拆分,直接复制h份(或通过广播机制),维度batch_size×h×seq_len×d_k(和Q的维度匹配,便于并行计算)。
3. 后续步骤(并行计算注意力、拼接头输出、最终线性融合)和MHA完全一致。

代码

classMultiQueryAttention(nn.Module):def__init__(self,d_model,num_heads,dropout=0.1):super(MultiQueryAttention,self).__init__()assertd_model%num_heads==0,"d_model must be divisible by num_heads"self.d_model=d_model self.num_heads=num_heads self.d_k=d_model//num_heads# 【MQA核心差异1】:Q有h组投影权重,K、V只有1组投影权重self.wq=nn.Linear(d_model,d_model)# Q:h组权重(通过后续拆分实现)self.wk=nn.Linear(d_model,self.d_k)# K:1组权重,输出维度为d_k(单个头的维度)self.wv=nn.Linear(d_model,self.d_k)# V:1组权重,输出维度为d_kself.wo=nn.Linear(d_model,d_model)self.dropout=nn.Dropout(dropout)self.norm=nn.LayerNorm(d_model)defscaled_dot_product_attention(self,q,k,v,mask=None):attn_scores=torch.matmul(q,k.transpose(-2,-1))attn_scores=attn_scores/math.sqrt(self.d_k)ifmaskisnotNone:attn_scores=attn_scores.masked_fill(mask==0,-1e9)attn_weights=torch.softmax(attn_scores,dim=-1)attn_weights=self.dropout(attn_weights)output=torch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,mask=None):batch_size=q.size(0)# 步骤1:线性投影(【MQA核心差异2】:K、V只做1组投影)q_proj=self.wq(q)# (batch_size, seq_len_q, d_model)k_proj=self.wk(k)# (batch_size, seq_len_k, d_k) —— 1组Kv_proj=self.wv(v)# (batch_size, seq_len_v, d_k) —— 1组V# 步骤2:拆分多头(【MQA核心差异3】:K、V复制h份,与Q匹配)# Q拆分:和MHA一致,(batch_size, seq_len_q, d_model) -> (batch_size, num_heads, seq_len_q, d_k)q_proj=q_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# K、V复制h份:(batch_size, seq_len_k, d_k) -> (batch_size, num_heads, seq_len_k, d_k)# 用广播机制实现复制,避免冗余计算(更高效)# unsqueeze是在第1维添加一个维度,变为batch_size, 1, seq_len_k, d_k。repeat是将第1维复制为num_heads份,其他维度保持不变。k_proj=k_proj.unsqueeze(1).repeat(1,self.num_heads,1,1)v_proj=v_proj.unsqueeze(1).repeat(1,self.num_heads,1,1)# 步骤3-5:和MHA完全一致attn_output,attn_weights=self.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)attn_output=attn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)output=self.wo(self.dropout(attn_output))output=self.norm(output+q)returnoutput,attn_weights

4.分组查询注意力机制

虽然多查询注意力机制很大程度上解决了多头注意力机制的计算开销大、随序列长度的增加推理速度慢,但是其表达能力会有损失:共享K、V会导致不同头的注意力计算依赖同一套候选信息,可能会丢失部分细粒度特征;在部分任务(如细粒度语义理解)中,可能会出现训练震荡,需要调参优化。
所以Google在2023年提出了一种介于两者之间的全新注意力机制分组查询注意力机制,它的做法是将h个Q头分成G组,每组共享一套K和V的投影权重——既不像MHA那样每个头都有独立K、V(开销大),也不像MQA那样所有头共享一套K、V(表达能力损失),实现了“表达能力”和“推理效率”的最优平衡。假设h=8(Q头数),g=2(分组数),那么每4个Q头为一组,每组共享1套K、V,总共需要2套K、V——KV Cache的开销是MHA的1/4(2/8),远低于MHA,同时表达能力比MQA更强(多组K、V能捕捉更多细粒度特征)。

GQA有三种变体:
GQA-1:一个单独的组,等同于 Multi-Query Attention (MQA)。
GQA-H:组数等于头数,基本上与 Multi-Head Attention (MHA) 相同。
GQA-G:一个中间配置,具有G个组,平衡了效率和表达能力。

实现过程

GQA的流程在MQA基础上增加了分组步骤,具体如下:
1.线性投影:

现有输入序列的维度为batch_size×seq_len×d_model,其中batch_size是批次大小,seq_len是序列长度,d_model是token的嵌入维度,h是注意力头数(需满足d_model能被h整除),具体流程如下:
1.线性投影:将Q、K、V分别通过3个独立的线性层(权重矩阵分别为d_model×d_model、d_model×(d_model // num_heads)*group、d_model×(d_model // num_heads)*group,得到投影后的Q、K、V,QKV的维度分别为batch_size×seq_len×d_model、batch_size×seq_len×(d_model // num_heads)*group、batch_size×seq_len×(d_model // num_heads)*group。
2.拆分多头:将投影后的Q、K、V拆分成不同的组和头,Q、K、V维度转换为batch_size×group, head//group, seq_len, d_k、batch_size×group, 1, seq_len, d_k、batch_size×group, 1, seq_len, d_k。通过广播机制,对KV中的头数自动扩展为对应维度的长度(此处1扩展为h/g),实现h/g个Q头共享1套KV,既高效又节省显存。
3.并行计算注意力:对每个头,独立执行缩放点积注意力计算,得到每个头的输出batch_size×group×h//group×seq_len×d_k。
4.拼接头输出:将h个头的输出拼接起来,维度还原为batch_size×seq_len×d_model。
5.最终线性融合:通过一个线性层(权重矩阵W_O:d_model×d_model),对拼接后的结果进行融合,得到最终输出batch_size×h×seq_len×d_model。

代码

classGroupedQueryAttention(nn.Module):def__init__(self,d_model,num_heads,num_kv_heads,dropout=0.1):super(GroupedQueryAttention,self).__init__()# 确保d_model能被num_heads整除(保证每个头维度d_k为整数)assertd_model%num_heads==0,"d_model must be divisible by num_heads"# 确保num_heads能被num_kv_heads整除(保证每组Q头数为整数)assertnum_heads%num_kv_heads==0,"num_heads must be divisible by num_kv_heads"self.d_model=d_model self.num_heads=num_heads# Q头数(h)self.num_kv_heads=num_kv_heads# K、V头数(分组数g)self.d_k=d_model//num_heads# 每个头的维度(d_k = d_model/h)self.heads_per_group=num_heads//num_kv_heads# 每组的Q头数(h/g)# Q、K、V线性层,契合通用流程的投影逻辑self.wq=nn.Linear(d_model,d_model)# 等价于nn.Linear(d_model, num_heads×d_k)self.wk=nn.Linear(d_model,self.num_kv_heads*self.d_k)# 输出g×d_kself.wv=nn.Linear(d_model,self.num_kv_heads*self.d_k)# 输出g×d_k# 最终线性融合层(与MHA一致)self.wo=nn.Linear(d_model,d_model)self.dropout=nn.Dropout(dropout)self.norm=nn.LayerNorm(d_model)# 复用缩放点积注意力子模块(与MHA完全一致)defscaled_dot_product_attention(self,q,k,v,mask=None):attn_scores=torch.matmul(q,k.transpose(-2,-1))attn_scores=attn_scores/math.sqrt(self.d_k)ifmaskisnotNone:attn_scores=attn_scores.masked_fill(mask==0,-1e9)attn_weights=torch.softmax(attn_scores,dim=-1)attn_weights=self.dropout(attn_weights)output=torch.matmul(attn_weights,v)returnoutput,attn_weightsdefforward(self,q,k,v,mask=None):batch_size=q.size(0)# 步骤1:线性投影(契合通用流程)q_proj=self.wq(q)# (batch_size, seq_len_q, d_model) → (bs, sl_q, h×d_k)k_proj=self.wk(k)# (batch_size, seq_len_k, num_kv_heads * d_k) → (bs, sl_k, g×d_k)v_proj=self.wv(v)# (batch_size, seq_len_v, num_kv_heads * d_k) → (bs, sl_v, g×d_k)# 步骤2:拆分多头与分组(契合通用流程的维度变化)q_proj=q_proj.view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)q_proj=q_proj.view(batch_size,self.num_kv_heads,self.heads_per_group,-1,self.d_k)k_proj=k_proj.view(batch_size,-1,self.num_kv_heads,self.d_k).transpose(1,2).unsqueeze(2)v_proj=v_proj.view(batch_size,-1,self.num_kv_heads,self.d_k).transpose(1,2).unsqueeze(2)# 步骤3:并行计算分组注意力attn_output,attn_weights=self.scaled_dot_product_attention(q_proj,k_proj,v_proj,mask)# 步骤4:拼接头输出,还原维度attn_output=attn_output.view(batch_size,self.num_heads,-1,self.d_k)attn_output=attn_output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)# 步骤5:最终线性融合+残差连接+层归一化output=self.wo(self.dropout(attn_output))output=self.norm(output+q)returnoutput,attn_weights

5.三者对比

三者注意力机制的对比如下:

对比维度多头注意力(MHA)分组查询注意力(GQA)多查询注意力(MQA)
核心特点每个Q头有独立的K、V头g个分组,每组Q头共享1套K、V头所有Q头共享1套K、V头
K/V头数等于Q头数(h)分组数(g,1<g<h)1
显存开销(KV Cache)最大(h组K、V)中等(g组K、V)最小(1组K、V)
推理速度最慢较快(接近MQA)最快
表达能力最强较强(接近MHA)较弱
实现复杂度中等较高(需分组)最低
训练稳定性最高较高较低
代表模型BERT、GPT-2、T5LLaMA 2/3、Mixtral、QwenFalcon、SantaCoder、StarCoder
适用场景对表达能力要求高,不计较推理速度(如小模型训练、细粒度任务)兼顾性能和效率(主流大模型、企业级部署)追求极致推理速度,允许轻微性能损失(端侧部署、长序列生成)
http://www.jsqmd.com/news/863499/

相关文章:

  • Windows Btrfs驱动完全指南:解锁Linux文件系统的7大核心优势
  • 新能源车辆数据处理平台架构
  • 告别克隆整个仓库:GitHub文件精准下载工具使用指南
  • Go 闭包【1】基础
  • 告别焦虑等待!Elsevier投稿状态自动追踪插件,让你的科研进度一目了然
  • 调用外部服务却无监控?这可能是下一个雪崩的源头
  • ContentBranch+CFBranch混合电影推荐模型|全网独家复现,深度学习实战篇 引入双分支融合架构,兼顾内容特征与协同信号、助力冷启动缓解、数据稀疏性优化、推荐精度有效涨点
  • 【硬件面试题精讲】运放求和 + 同相放大电路输出计算(附原理与通用公式)
  • 淘金币自动化脚本:5分钟搞定淘宝每日任务,轻松解放双手
  • 苏州德奥诚汽车服务:太仓靠谱的报废车回收推荐哪几家 - LYL仔仔
  • Go闭包【2】 1.22 对 for 循环里闭包陷阱的那个“史诗级更新”
  • HoRain云--AI 底层架构
  • QQ音乐加密文件终极转换指南:3步将.qmc文件转为MP3/FLAC
  • 达梦数据库-堆栈看问题-01-asmapi_asm_extent_load
  • 如何在Windows上实现专业级游戏控制器模拟:ViGEmBus驱动深度解析
  • DS4Windows终极指南:如何在Windows上完美使用PS4/PS5手柄玩所有游戏
  • Warcraft Helper:现代Windows环境下魔兽争霸3兼容性技术解决方案深度解析
  • TranslucentTB:Windows任务栏透明化终极指南与5大创意应用场景
  • 你的 BroadcastReceiver 为何在后台装死?—— Android 8.0+ 隐式广播限制与动态注册完全指南
  • 苏州购宠避坑指南|5 家靠谱实体门店 - 资讯速览
  • 2026年5月论文降 AI 率工具终极推荐:超过一半学生的选择,早标网为何降AI效果好? - 全维度降AI
  • 10.Python 迭代器、生成器与装饰器 深度解析
  • 3分钟快速上手SketchUp STL插件:终极3D打印模型转换完整指南
  • [MAF的Agent管道详解-04]如何让LLM按照要求的结构输出数据?
  • 浏览器资源嗅探革命:猫抓扩展如何重新定义在线媒体捕获体验
  • 如何快速安装BetterNCM:终极网易云音乐插件管理指南
  • 深度解析Unity游戏实时翻译插件:XUnity.AutoTranslator的5大实战应用场景与架构设计
  • 大学买不到GPU怪我?黄仁勋斯坦福现场火力全开:是你们体制的错!
  • Sub2API + CCSwitch 实现 Codex 反向代理:多账号流量分发实战(解决codex手机号验证)可以润色吗
  • 【紧急更新】Midjourney v6.1金属纹理算法变更预警:3个必须重训的材质参数阈值,错过将导致PBR贴图链断裂