从原理到实战:一文彻底吃透Transformer架构
2017年Google发表的《Attention Is All You Need》论文,用仅11页的篇幅彻底改写了深度学习的历史。8年过去了,Transformer从一篇论文变成了驱动ChatGPT、GPT-4、BERT等所有主流大模型的“钢筋骨架”。今天,我们从零开始,把这座AI摩天大楼彻底拆解一遍。
一、为什么需要Transformer?
在Transformer问世之前,NLP领域的主流方案是RNN(循环神经网络)和LSTM(长短期记忆网络)。但它们有两个致命缺陷:
缺陷1:串行计算,训练慢到怀疑人生
RNN处理序列时必须按时间步依次计算——必须先处理第一个词,才能处理第二个词。这种串行模式根本无法利用GPU的并行计算能力。处理长文本时,训练时间会呈指数级增长。
缺陷2:长距离依赖,记性力堪比金鱼
RNN天生有“短期记忆”缺陷。比如翻译“那个戴着红帽子、拿着黑伞的男人走进了商店”这句话,RNN很容易把前面的修饰成分忘得一干二净,最后只输出“一个男人走进了商店”。LSTM和GRU虽然有所缓解,但并没有从根本上解决问题。
CNN虽然可以并行,但建模长距离依赖需要堆叠大量卷积层,参数冗余且效率低下。
正是在这样的背景下,《Attention Is All You Need》横空出世——完全抛弃了RNN和CNN,只靠注意力机制打天下。
二、自注意力机制:Transformer的灵魂
2.1 什么是注意力机制?
注意力机制的核心灵感来自人类的注意力习惯——看图片时会聚焦主体、忽略背景,读文本时会关注关键词、弱化无关词汇。它让模型能自适应地为输入信息分配不同的权重,重点关注重要信息。
所有注意力机制都基于QKV(Query-Key-Value)三元组框架。用查字典来理解:
Query(查询):你想查的问题,比如“注意力机制是什么?”
Key(键):字典里每个词条的标题
Value(值):字典里每个词条的具体内容
你的Query和每个Key算一个“相关度分数”,分数越高,对应的Value就越值得被关注。
2.2 自注意力的数学原理
自注意力(Self-Attention)是指序列内部元素之间的注意力——每个词都要“关注”序列中所有其他词。
给定输入序列 X∈Rn×dX∈Rn×d(nn为序列长度,dd为特征维度),首先通过三个线性层生成Q、K、V矩阵:
import torch import torch.nn as nn class SelfAttention(nn.Module): def __init__(self, d_model): super().__init__() self.d_k = d_model // 8 # 每个头的维度 self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.softmax = nn.Softmax(dim=-1) def forward(self, x): Q = self.W_q(x) # [batch, seq_len, d_model] K = self.W_k(x) V = self.W_v(x) # ... 后续计算
然后计算注意力分数:
Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V)=softmax(dkQKT)V
这个公式的核心是Q和K的点积——点积越大,说明两个位置的相关性越强。
为什么要除以 dkdk?因为当维度较大时,点积的数值会很大,导致softmax进入梯度饱和区。除以 dkdk 可以让梯度保持在稳定范围。
2.3 多头注意力:一个模型,多种视角
单头注意力只能捕捉一种依赖关系。但语言是复杂的——一个词可能同时和多个词有不同类型的关系。
多头注意力(Multi-Head Attention)将Q、K、V投影到多个子空间(论文中默认8个头),每个头独立计算注意力,最后把所有头的结果拼接起来。
class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() self.num_heads = num_heads self.d_k = d_model // num_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) def forward(self, x): batch_size, seq_len, _ = x.size() # 线性变换并拆分为多头 Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) # 计算注意力 scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5) attn = torch.softmax(scores, dim=-1) out = torch.matmul(attn, V) # 合并多头 out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) return self.W_o(out)
多头注意力的妙处在于:不同的头可以关注不同层面的信息。比如在翻译任务中,一个头可能专注语法结构,另一个头关注语义角色。
三、Transformer整体架构
Transformer采用经典的编码器-解码器(Encoder-Decoder)结构:
┌─────────────────┐ ┌─────────────────┐ │ Encoder │ │ Decoder │ │ (N×堆叠) │ │ (N×堆叠) │ └────────┬────────┘ └────────┬────────┘ │ │ └──────────┬───────────────┘ ▼ ┌─────────────────┐ │ 线性输出层 │ └─────────────────┘
原始论文中,编码器和解码器各堆叠6层。
3.1 编码器(Encoder)
每个编码器层包含两个子层:
多头自注意力层:让每个位置关注输入序列中的所有位置
前馈神经网络(FFN):对每个位置的表示做进一步变换
每个子层后面都跟着残差连接和层归一化:
output=LayerNorm(x+Sublayer(x))output=LayerNorm(x+Sublayer(x))
残差连接的作用是缓解梯度消失,让模型可以堆叠得更深。层归一化则通过标准化输入分布来加速训练收敛。
3.2 解码器(Decoder)
解码器比编码器多了一个子层:
掩码多头自注意力:和编码器类似,但要防止看到“未来的词”(用掩码遮住当前位置之后的信息)
交叉注意力:Q来自解码器,K和V来自编码器的输出——让解码器在生成时“参考”源语言的信息
前馈神经网络
3.3 位置编码:没有顺序怎么办?
Transformer的自注意力机制本身不具备感知位置的能力。换句话说,模型会把“我爱你”和“你爱我”当成同一个输入——这显然不行。
为了解决这个问题,Transformer引入了位置编码(Positional Encoding)。原始论文使用正弦和余弦函数:
PE(pos,2i)=sin(pos/100002i/dmodel)PE(pos,2i)=sin(pos/100002i/dmodel)
PE(pos,2i+1)=cos(pos/100002i/dmodel)PE(pos,2i+1)=cos(pos/100002i/dmodel)
这种编码方式有三个优点:
确定性:相同位置的编码永远一致
相对关系一致性:任意两个位置的相对距离关系保持一致
泛化能力:可以处理比训练时更长的序列
最终的输入表示是词嵌入 + 位置编码 + 段嵌入(三者直接相加)。
四、从零实现一个Transformer(PyTorch精简版)
理论讲完了,上代码。我们实现一个完整的Transformer编码器:
import torch import torch.nn as nn import math class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.register_buffer('pe', pe) def forward(self, x): return x + self.pe[:, :x.size(1), :] class TransformerEncoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout=0.1): super().__init__() self.self_attn = MultiHeadAttention(d_model, num_heads) self.ffn = nn.Sequential( nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model) ) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): # 自注意力 + 残差 + 层归一化 attn_out = self.self_attn(x) x = self.norm1(x + self.dropout(attn_out)) # 前馈网络 + 残差 + 层归一化 ffn_out = self.ffn(x) x = self.norm2(x + self.dropout(ffn_out)) return x class TransformerEncoder(nn.Module): def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_len): super().__init__() self.embedding = nn.Embedding(vocab_size, d_model) self.pos_encoding = PositionalEncoding(d_model, max_len) self.layers = nn.ModuleList([ TransformerEncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers) ]) def forward(self, x): x = self.embedding(x) x = self.pos_encoding(x) for layer in self.layers: x = layer(x) return x五、Transformer的经典变体
Transformer的编码器和解码器可以单独使用,衍生出不同的模型家族:
| 模型 | 使用部分 | 特点 | 代表模型 |
|---|---|---|---|
| Encoder-only | 仅编码器 | 双向理解,擅长分类、问答 | BERT |
| Decoder-only | 仅解码器 | 单向生成,擅长文本续写 | GPT系列 |
| Encoder-Decoder | 两者都用 | 序列到序列,擅长翻译、摘要 | T5、BART |
BERT用的是Transformer的编码器部分,能同时看到上下文(双向)。GPT用的是解码器部分,只能从左到右看(单向),但正因为这种单向设计,它才能做自回归生成。这也是为什么GPT系列能“涌现”出强大的生成能力。
六、Transformer的应用早已不限于NLP
Transformer最初为机器翻译而生,但如今早已“出圈”:
计算机视觉:ViT(Vision Transformer)将图像切成一个个patch,当成“词”来处理
语音识别:将音频信号转为序列,用Transformer建模
多模态融合:同时处理文本、图像、音频等多种模态
推荐系统:将用户行为序列用Transformer建模
七、总结
Transformer的成功可以归结为三点:
全并行计算:摆脱了RNN的串行束缚,训练效率提升几个数量级
直接建模长距离依赖:任意两个位置都可以直接“对话”,不再有信息衰减
极致的可扩展性:从1亿参数到万亿参数,Transformer都能Hold住
从2017年那篇仅11页的论文,到如今统治整个AI领域的“万能架构”,Transformer的故事远未结束。无论你是想入门AI、准备面试,还是想深入大模型底层原理,吃透Transformer都是绕不开的第一步。
如果这篇文章帮到了你,欢迎点赞、收藏、转发!
