当前位置: 首页 > news >正文

浅析注意力(Attention)机制(三)-- Multi-Head Attention多头自注意力机制

上期我们讲完了针对序列中单个元素的自注意力机制计算方法,为了更方便理解多头自注意力机制,我们先将自注意力机制改写为矩阵形式。

3.1 Self-Attention的矩阵形式

在实际计算中,序列的各个元素并行输入到自注意力机制中,将所有位置的向量堆叠起来

\[\mathbf{X}= \begin{bmatrix} \left(\mathbf{x}_{(1)}\right)^T \\ \left(\mathbf{x}_{(2)}\right)^T \\ \vdots \\ \left(\mathbf{x}_{(n)}\right)^T \end{bmatrix}\in\mathbb{R}^{n\times d}\]

\(Q、K、V\)三参数的计算方法也可以改写为矩阵形式,即:

\[\begin{aligned} \mathbf{Q} & =\mathbf{XW}_Q\in\mathbb{R}^{n\times d_k} \\ \mathbf{K} & =\mathbf{XW}_K\in\mathbb{R}^{n\times d_k} \\ \mathbf{V} & =\mathbf{XW}_V\in\mathbb{R}^{n\times d_v} \end{aligned}\]

同理,相关性矩阵可由以下公式计算:

\[\mathbf{S} = \frac{\mathbf{Q}\mathbf{K}^{T}}{\sqrt{d_k}} \]

最后经过softmax归一化得到注意力权重,再和\(V\)加权求和,这也是自注意力机制的核心公式:

\[\mathbf{Z} = Attention(\mathbf{Q}, \mathbf{K},\mathbf{V}) = softmax(\frac{\mathbf{Q}\mathbf{K}^{T}}{\sqrt{d_k}})V \]

以上就是自注意力机制的矩阵形式,但单头的自注意力机制只能用一种视角给序列建模,我们自然而然能联想到,能否用并行的方式计算出多个注意力权重,然后融合起来呢?这就是多头自注意力机制的基本思想。

3.2 计算过程

正所谓“一生二、二生三、三生万物”,多头自注意力机制的核心思路很直观,就是通过多个注意力头并行,计算出一组注意力权重,每次关注不同的子空间,这样模型就能从多个视角挖掘序列中各个元素的依赖关系,即:

\[\mathbf{Z}_1, \mathbf{Z}_2, ... , \mathbf{Z}_h \]

随后将这一组注意力权重直接拼接,没错,就是直接拼起来,得到一个更大的注意力矩阵

\[\mathbf{Z} = Concat(\mathbf{Z}_1, \mathbf{Z}_2, ... , \mathbf{Z}_h) \]

随后将拼接后的矩阵输入一个线性层,因为拼接后的注意力矩阵只是将不同注意力头的信息并排放在一起,来自不同注意力头的权重之间并没有发现关联,因此要实现真正的融合,必须进行一次线性变换,即:

\[Z_{final} = \mathbf{W}_O\mathbf{Z} \]

由此,所有头的信息被打散并重新组合,模型可以自由地学习跨多头的特征关系。

把公式合并如下,得到多头注意力的融合公式:

\[Z_{final} = Multihead(\mathbf{Q}, \mathbf{K}, \mathbf{V})$$ = Concat(\mathbf{Z}_1, \mathbf{Z}_2, ... , \mathbf{Z}_h)W_O \]

3.3 维度关系

我们知道,Transformer 中的注意力机制是多层叠加使用的,为了让模型中的数据流动更加简洁,最好让向量流过注意力机制后维度保持不变

\[d_k=d_v=\frac{d}{h} \]

例如输入的编码\(\mathbf{x_{(i)}}\)位数为512,注意力头8个,则每个注意力头计算出的注意力维数为512/8=64,这个整除关系必须满足,那么\(d_k, d_v\)长度均为64,正好64是完全平方数,方便计算注意力分数时进行开方操作。

3.4 代码实现

import torch
import torch.nn as nn
import copy, math
import torch.nn.functional as F
from torch.autograd import Variableclass Encoder(nn.Module):def __init__(self, layer, N):super(Encoder, self).__init__()self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])self.norm = nn.LayerNorm(layer.size)def forward(self, x, mask):"""Pass the input (and mask) through each layer in turn."""for layer in self.layers:x = layer(x, mask)return self.norm(x)class SublayerConnection(nn.Module):"""A residual connection followed by a layer norm.Note for code simplicity the norm is first as opposed to last."""def __init__(self, size, dropout):super(SublayerConnection, self).__init__()self.norm = nn.LayerNorm(size)self.dropout = nn.Dropout(dropout)def forward(self, x, sublayer):"""Apply residual connection to any sublayer with the same size."""return x + self.dropout(sublayer(self.norm(x)))class EncoderLayer(nn.Module):def __init__(self, size, num_heads, feed_forward, dropout):super(EncoderLayer, self).__init__()self.self_attn = nn.MultiheadAttention(size, num_heads)self.feed_forward = feed_forwardself.sublayer = nn.ModuleList([SublayerConnection(size, dropout) for _ in range(2)])self.size = sizedef forward(self, x, mask):x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))return self.sublayer[1](x, self.feed_forward)class Decoder(nn.Module):def __init__(self, layer, N):super(Decoder, self).__init__()self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])self.norm = nn.LayerNorm(layer.size)def forward(self, x, memory, src_mask, tgt_mask):for layer in self.layers:x = layer(x, memory, src_mask, tgt_mask)return self.norm(x)class DecoderLayer(nn.Module):def __init__(self, size, num_heads, feed_forward, dropout):super(DecoderLayer, self).__init__()self.self_attn = nn.MultiheadAttention(size, num_heads)self.src_attn = nn.MultiheadAttention(size, num_heads)self.feed_forward = feed_forwardself.sublayer = nn.ModuleList([SublayerConnection(size, dropout) for _ in range(3)])self.size = sizedef forward(self, x, memory, src_mask, tgt_mask):x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))x = self.sublayer[1](x, lambda x: self.src_attn(x, memory, memory, src_mask))return self.sublayer[2](x, self.feed_forward)class PositionwiseFeedForward(nn.Module):"""Implements FFN equation."""def __init__(self, d_model, d_ff, dropout=0.1):super(PositionwiseFeedForward, self).__init__()self.w_1 = nn.Linear(d_model, d_ff)self.w_2 = nn.Linear(d_ff, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.w_2(self.dropout(F.relu(self.w_1(x))))class Embeddings(nn.Module):def __init__(self, d_model, vocab):super(Embeddings, self).__init__()self.embeddings = nn.Embedding(vocab, d_model)self.d_model = d_modeldef forward(self, x):return self.embeddings(x)*math.sqrt(self.d_model)class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2)*-(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):x = x + Variable(self.pe[:, :x.size(1)],requires_grad=False)return self.dropout(x)class Generator(nn.Module):"""Define standard linear + softmax generation step."""def __init__(self, d_model, vocab):super(Generator, self).__init__()self.proj = nn.Linear(d_model, vocab)def forward(self, x):return F.log_softmax(self.proj(x), dim=-1)class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):super(EncoderDecoder, self).__init__()self.encoder = encoderself.decoder = decoderself.src_embed = src_embedself.tgt_embed = tgt_embedself.generator = generatordef forward(self, src, tgt, src_mask, tgt_mask):"""Take in and process masked src and target sequences."""return self.decode(self.encode(src, src_mask), src_mask,tgt, tgt_mask)def encode(self, src, src_mask):return self.encoder(self.src_embed(src), src_mask)def decode(self, memory, src_mask, tgt, tgt_mask):return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)def make_model(src_vocab, tgt_vocab, d_model=512, N=6, h=8, d_ff=2048, dropout=0.1):c = copy.deepcopyff = PositionwiseFeedForward(d_model, d_ff, dropout=dropout)position = PositionalEncoding(d_model, dropout=dropout)# 编码长度embeddings512,注意力头head8个,则每个注意力头计算出的注意力分数维数为512/8=64,这个整除关系必须满足# embeddings和head的商必须是完全平方数,因为计算注意力分数时需要进行开方操作# 多头注意力就是将所有注意力头的计算结果拼接以后再进行线性变换的操作,最终输出的矩阵与输入相同,这样才可以堆叠model = EncoderDecoder(Encoder(EncoderLayer(d_model, h, c(ff), dropout), N),Decoder(DecoderLayer(d_model, h, c(ff), dropout), N),nn.Sequential(Embeddings(d_model, src_vocab), c(position)),nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),Generator(d_model, tgt_vocab))# This was important from their code.# Initialize parameters with Glorot / fan_avg.for p in model.parameters():if p.dim() > 1:nn.init.xavier_uniform(p)return model
http://www.jsqmd.com/news/761861/

相关文章:

  • 实验室安全管理与操作效率提升实践指南
  • 2025届最火的六大降AI率网站推荐
  • java小白福音:用快马ai生成带注释的入门代码,轻松理解jdk核心
  • ClawApp爬虫框架:从零构建工程化数据采集应用
  • WinDbg的使用方法(分析蓝屏原因)
  • 家电口碑战怎么拆评论
  • 深入解析Cappuccino:现代前端状态逻辑管理框架的设计与实践
  • 2026年4月靠谱的橡胶垫板供应商口碑推荐,压轨器/轨距挡板/橡胶垫板/轨道压板/螺旋道钉,橡胶垫板订做厂家怎么选择 - 品牌推荐师
  • 用STM32 HAL库驱动WS2812B:从CubeMX配置到流水灯效果,一个视频全搞定(F103C8T6+PWM+DMA)
  • SSH终端集成AI助手:构建智能命令行副驾驶的实践指南
  • aicommit2:基于AI的Git提交信息自动生成工具实践指南
  • PySpark DataFrame实战:从CSV文件到SQL式分析,一条龙搞定用户画像分析
  • 国内主流隔油池源头厂家实力排行实测盘点:隔油提升一体化设备厂家/隔油提升设备/食品厂污水处理设备/食品厂油水分离器/选择指南 - 优质品牌商家
  • 别再让触摸板失灵了!FPC柔性电路板布线避坑指南(附PCB设计实例)
  • Packforge:声明式构建编排工具,统一多项目CI/CD流程
  • 2026年玻璃钢排水渠优质产品推荐榜:玻璃钢罐体、玻璃钢运输罐、高速急流槽、u型排水沟、农田灌溉排水渠、化工储罐选择指南 - 优质品牌商家
  • Hadoop核心目录深度解析:架构师必备功能清单及应用场景
  • Vue3——使用Mock.js
  • 效率倍增:用快马平台一键生成优化版dfs代码框架,告别重复劳动
  • 基于MLP的孪生网络目标跟踪算法研究
  • 嵌入式BIOS开发:硬件初始化与电源管理优化实践
  • 2026年山东大学项目实训项目记录(三)
  • Godot 4多窗口游戏开发:实现角色跨窗口移动与视口共享
  • 2026农业灌溉储水箱优质厂家推荐榜:不锈钢高位消防水箱、二次变频恒压供水设备、二次恒压供水设备、农业灌溉储蓄水箱,选择指南 - 优质品牌商家
  • 告别命令行!用C# Winform给Tibco RV做个可视化调试工具(附源码)
  • 贸易展销实战指南:从展台设计到订单转化的全流程技能拆解
  • LLM红队测试实战:T-MAP提升AI风控3-7倍覆盖率
  • TWIG框架:平衡文本到图像生成的精确控制与创意发散
  • LLM动态网页生成技术:从自然语言到交互界面
  • 开发提速:用快马AI一键生成oh-my-openagent通用工具类代码