Transformer 核心模块详解:多头注意力、前馈网络与词嵌入
【学习记录】Transformer 核心模块详解:多头注意力、前馈网络与词嵌入
Transformer 是现代大语言模型的基石,而多头注意力(MultiHeadAttention)、前馈网络(FFN)和词嵌入(Embedding)是其最核心的三个组件。本文从原理到代码,逐层拆解这三个模块,并提供 Python(PyTorch)和 C++(LibTorch)实现,附带完整的复杂度分析。
📌 目录
- MultiHeadAttention(多头注意力)
- FFN(前馈网络)
- Embedding(词嵌入)
- 三个模块的组合使用
- 复杂度总结
一、多头注意力(MultiHeadAttention)
1.1 作用
多头注意力机制允许模型同时关注输入序列中不同位置的不同表示子空间。它通过将查询(Q)、键(K)、值(V)线性映射到多个头,分别计算注意力,最后拼接并映射回原维度。
1.2 数学公式
标准缩放点积注意力:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V
多头注意力:
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)。
1.3 代码实现(Python/PyTorch)
importtorchimporttorch.nnasnnimportmathclassMultiHeadAttention(nn.Module):def__init__(self,d_model,num_heads):super().__init__()assertd_model%num_heads==0self.d_model=d_model self.num_heads=num_heads self.d_k=d_model//num_heads self.Wq=nn.Linear(d_model,d_model)self.Wk=nn.Linear(d_model,d_model)self.Wv=nn.Linear(d_model,d_model)self.Wo=nn.Linear(d_model,d_model)defforward(self,Q,K,V,mask=None):batch_size=Q.size(0)# 1. 线性映射并拆分为多头Q=self.Wq(Q).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)K=self.Wk(K).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)V=self.Wv(V).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# 2. 缩放点积注意力scores=torch.matmul(Q,K.transpose(-2,-1))/math.sqrt(self.d_k)# 3. 应用掩码(可选)ifmaskisnotNone:scores=scores.masked_fill(mask==0,-1e9)attn_weights=torch.softmax(scores,dim=-1)output=torch.matmul(attn_weights,V)# 4. 合并多头并输出output=output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)returnself.Wo(output)1.4 图解(文本示意)
输入: (B, T, D) │ ├─→ 线性映射 Wq, Wk, Wv → (B, T, D) │ ├─→ view + transpose → (B, n_head, T, d_k) │ ├─→ scores = Q @ K^T / sqrt(d_k) → (B, n_head, T, T) │ │ │ └─→ mask (可选) 填充 -1e9 │ ├─→ softmax → (B, n_head, T, T) │ ├─→ output = attn @ V → (B, n_head, T, d_k) │ ├─→ transpose + view → (B, T, D) │ └─→ Wo 线性映射 → (B, T, D)1.5 C++ 代码(LibTorch)
#include<torch/torch.h>classMultiHeadAttentionImpl:publictorch::nn::Module{public:intd_model,num_heads,d_k;torch::nn::Linear Wq,Wk,Wv,Wo;MultiHeadAttentionImpl(intd_model_,intnum_heads_):d_model(d_model_),num_heads(num_heads_),d_k(d_model_/num_heads_),Wq(torch::nn::Linear(d_model,d_model)),Wk(torch::nn::Linear(d_model,d_model)),Wv(torch::nn::Linear(d_model,d_model)),Wo(torch::nn::Linear(d_model,d_model)){register_module("Wq",Wq);register_module("Wk",Wk);register_module("Wv",Wv);register_module("Wo",Wo);}torch::Tensorforward(torch::Tensor Q,torch::Tensor K,torch::Tensor V,torch::Tensor mask={}){intbatch_size=Q.size(0);// 线性映射Q=Wq->forward(Q).view({batch_size,-1,num_heads,d_k}).transpose(1,2);K=Wk->forward(K).view({batch_size,-1,num_heads,d_k}).transpose(1,2);V=Wv->forward(V).view({batch_size,-1,num_heads,d_k}).transpose(1,2);// 注意力分数autoscores=torch::matmul(Q,K.transpose(-2,-1))/std::sqrt(d_k);if(mask.defined()){scores=scores.masked_fill(mask==0,-1e9);}autoattn=torch::softmax(scores,-1);autooutput=torch::matmul(attn,V);output=output.transpose(1,2).contiguous().view({batch_size,-1,d_model});returnWo->forward(output);}};TORCH_MODULE(MultiHeadAttention);1.6 复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 线性映射 (Q,K,V) | O(B×T×D²) | O(B×T×D) |
| 拆分多头 | O(B×T×D) | O(B×n_head×T×d_k) |
| 分数矩阵乘法 | O(B×n_head×T²×d_k) | O(B×n_head×T²) |
| Softmax | O(B×n_head×T²) | O(B×n_head×T²) |
| 加权求和 | O(B×n_head×T²×d_k) | O(B×n_head×T×d_k) |
| 合并与输出映射 | O(B×T×D²) | O(B×T×D) |
| 总计 | O(B × T² × D) | O(B × n_head × T²) |
其中
D = d_model,d_k = D / n_head。
二、前馈网络(FFN)
2.1 作用
FFN 对每个位置独立进行非线性变换,增加模型表达能力。标准结构:线性 → ReLU → 线性,通常中间维度d_ff是d_model的 4 倍左右。
2.2 数学公式
FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2FFN(x)=ReLU(xW1+b1)W2+b2
2.3 代码实现(Python/PyTorch)
classFFN(nn.Module):def__init__(self,d_model,d_ff):super().__init__()self.linear1=nn.Linear(d_model,d_ff)self.linear2=nn.Linear(d_ff,d_model)self.activation=nn.ReLU()defforward(self,x):returnself.linear2(self.activation(self.linear1(x)))2.4 图解
输入 (B, T, D) │ ├─→ linear1 (D → d_ff) → (B, T, d_ff) │ ├─→ ReLU → (B, T, d_ff) │ └─→ linear2 (d_ff → D) → (B, T, D)2.5 C++ 代码(LibTorch)
classFFNImpl:publictorch::nn::Module{public:torch::nn::Linear linear1,linear2;FFNImpl(intd_model,intd_ff):linear1(d_model,d_ff),linear2(d_ff,d_model){register_module("linear1",linear1);register_module("linear2",linear2);}torch::Tensorforward(torch::Tensor x){returnlinear2->forward(torch::relu(linear1->forward(x)));}};TORCH_MODULE(FFN);2.6 复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| linear1 | O(B × T × D × d_ff) | O(B × T × d_ff) |
| ReLU | O(B × T × d_ff) | O(B × T × d_ff) |
| linear2 | O(B × T × d_ff × D) | O(B × T × D) |
| 总计 | O(B × T × D × d_ff) | O(B × T × max(D, d_ff)) |
当
d_ff = 4 × D时,复杂度约为O(4 × B × T × D²)。
三、词嵌入(Embedding)
3.1 作用
将离散的 token ID 序列映射为稠密的连续向量,并乘以√d_model进行缩放,以便与位置编码相加时尺度匹配。
3.2 代码实现(Python/PyTorch)
classEmbedding(nn.Module):def__init__(self,vocab_size,d_model):super().__init__()self.embedding=nn.Embedding(vocab_size,d_model)self.d_model=d_modeldefforward(self,x):returnself.embedding(x)*math.sqrt(self.d_model)3.3 图解
输入: (B, T) token IDs [ [1, 3, 0, ...] ] │ └─→ nn.Embedding 查表 (vocab_size × D) │ └─→ 输出 (B, T, D) │ └─→ 乘以 √D → (B, T, D)3.4 C++ 代码(LibTorch)
classEmbeddingImpl:publictorch::nn::Module{public:torch::nn::Embedding embedding;intd_model;EmbeddingImpl(intvocab_size,intd_model_):embedding(vocab_size,d_model_),d_model(d_model_){register_module("embedding",embedding);}torch::Tensorforward(torch::Tensor x){returnembedding->forward(x)*std::sqrt(d_model);}};TORCH_MODULE(Embedding);3.5 复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 查表 | O(B × T) | O(B × T × D) |
| 乘法 | O(B × T × D) | O(B × T × D) |
| 总计 | O(B × T × D) | O(B × T × D) |
四、三个模块的组合使用
一个完整的 Transformer 编码器层通常由多头注意力 + 残差连接 + 层归一化和FFN + 残差连接 + 层归一化构成。
classTransformerEncoderLayer(nn.Module):def__init__(self,d_model,num_heads,d_ff):super().__init__()self.self_attn=MultiHeadAttention(d_model,num_heads)self.ffn=FFN(d_model,d_ff)self.norm1=nn.LayerNorm(d_model)self.norm2=nn.LayerNorm(d_model)defforward(self,x,mask=None):# 自注意力 + 残差 + 层归一化attn_out=self.self_attn(x,x,x,mask)x=self.norm1(x+attn_out)# FFN + 残差 + 层归一化ffn_out=self.ffn(x)x=self.norm2(x+ffn_out)returnx完整流程示例
vocab_size=10000d_model=512num_heads=8d_ff=2048batch_size=2seq_len=10# 输入 token IDsinput_ids=torch.randint(0,vocab_size,(batch_size,seq_len))# 嵌入层embed=Embedding(vocab_size,d_model)x=embed(input_ids)# (2,10,512)# 位置编码(此处略,可加上)# pos_enc = PositionalEncoding(d_model)# x = pos_enc(x)# Transformer 编码器层encoder_layer=TransformerEncoderLayer(d_model,num_heads,d_ff)output=encoder_layer(x)# (2,10,512)print(output.shape)# torch.Size([2, 10, 512])五、复杂度总结
| 模块 | 时间复杂度 | 空间复杂度 | 说明 |
|---|---|---|---|
| MultiHeadAttention | O(B × T² × D) | O(B × n_head × T²) | 核心瓶颈在 T²,长序列需优化 |
| FFN | O(B × T × D × d_ff) | O(B × T × max(D, d_ff)) | 通常 d_ff = 4D,复杂度约为 4× |
| Embedding | O(B × T × D) | O(B × T × D) | 查表操作,轻量级 |
优化建议:
- 对于长序列(T 很大),可使用稀疏注意力(如 FlashAttention)降低 T² 复杂度。
- FFN 的中间维度 d_ff 越大模型容量越大,但计算量线性增加。
- 嵌入层占参数量主要为
vocab_size × D,大词表时需考虑参数共享或压缩。
🎯 总结
本文详细拆解了 Transformer 的三个核心模块:
- 多头注意力:让模型关注不同位置的多种关系,是 Transformer 成功的核心。
- 前馈网络:提供非线性变换,增强模型表达能力。
- 词嵌入:将离散符号映射到连续空间,是深度学习处理文本的起点。
通过理解这些模块的输入输出、形状变化和复杂度,能轻松搭建并优化自己的 Transformer 模型。
