Writing a Transformer Encoder from Scratch in PyTorch: A Step-by-Step Guide from Theory to Code
In natural language processing, the Transformer architecture has become the de facto standard. This article walks you through implementing a complete Transformer Encoder module in PyTorch from scratch. Rather than simply calling library APIs, we dig into the implementation details of every component so you truly understand how it works.
1. Environment Setup and Basic Architecture
First, make sure your environment has PyTorch 1.8+ installed. We start from the most basic module:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class TransformerEncoder(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Stack of identical encoder layers (EncoderLayer is defined in section 3.2)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        output = src
        for layer in self.layers:
            output = layer(output, src_mask, src_key_padding_mask)
        return output
```

This skeleton defines the Encoder's core parameters:
- d_model: model dimension (default 512)
- nhead: number of attention heads (default 8)
- num_layers: number of encoder layers (default 6)
- dim_feedforward: feed-forward network dimension (default 2048)
- dropout: dropout rate (default 0.1)
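As a quick sanity check, the stack maps a (batch, seq_len, d_model) tensor to a tensor of the same shape. This is a minimal sketch that assumes the EncoderLayer from section 3.2 below has already been defined:

```python
# Quick shape check; assumes EncoderLayer (section 3.2) is already defined
encoder = TransformerEncoder(d_model=512, nhead=8, num_layers=6)
x = torch.randn(2, 10, 512)      # (batch, seq_len, d_model)
print(encoder(x).shape)          # torch.Size([2, 10, 512])
```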
2. Implementing Multi-Head Self-Attention
Multi-head attention is the core component of the Transformer. Let's break the implementation down:
2.1 The Basic Attention Computation
```python
def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
    # q, k, v: (..., seq_len, d_k); mask must broadcast to the score matrix
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 are blocked from being attended to
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, v), p_attn
```

Key points:
- The scaling factor math.sqrt(d_k) keeps the dot products from growing too large, which would push the softmax into regions with vanishing gradients
- masked_fill handles both padding masks and sequence masks
- The function returns both the weighted values and the attention weights
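A minimal usage sketch (the tensor sizes are illustrative):

```python
# Single-head example: 2 sequences of length 5 with d_k = 64
q = k = v = torch.randn(2, 5, 64)
out, attn = scaled_dot_product_attention(q, k, v)
print(out.shape, attn.shape)     # torch.Size([2, 5, 64]) torch.Size([2, 5, 5])
print(attn.sum(dim=-1))          # each row of attention weights sums to 1
```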
2.2 Multi-Head Attention Implementation
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        assert d_model % nhead == 0
        self.d_k = d_model // nhead
        self.nhead = nhead
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        # mask should be broadcastable to (batch, nhead, seq_q, seq_k)
        batch_size = q.size(0)

        # Linear projections, then split into heads:
        # (batch, seq, d_model) -> (batch, nhead, seq, d_k)
        q = self.q_linear(q).view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.nhead, self.d_k).transpose(1, 2)

        # Attention applied per head
        scores, attn = scaled_dot_product_attention(q, k, v, mask, self.dropout)

        # Concatenate heads: (batch, nhead, seq, d_k) -> (batch, seq, d_model)
        concat = scores.transpose(1, 2).contiguous() \
                       .view(batch_size, -1, self.nhead * self.d_k)
        return self.out(concat)
```

Implementation details:
- Each head works on a subspace of dimension d_model // nhead
- A linear projection followed by a reshape splits the input into heads
- The heads are concatenated and passed through a final linear layer
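A quick shape check (illustrative sizes):

```python
mha = MultiHeadAttention(d_model=512, nhead=8)
x = torch.randn(2, 10, 512)      # (batch, seq_len, d_model)
print(mha(x, x, x).shape)        # torch.Size([2, 10, 512])
```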
3. Feed-Forward Network and Residual Connections
3.1 Position-wise Feed-Forward Network
```python
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # Applied independently at every position
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
```

This simple two-layer network performs:
- Expansion (d_model → d_ff)
- ReLU activation
- Dropout regularization
- Projection back down to d_model
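A one-line sanity check (illustrative sizes):

```python
ffn = PositionwiseFeedForward(d_model=512, d_ff=2048)
x = torch.randn(2, 10, 512)      # (batch, seq_len, d_model)
print(ffn(x).shape)              # torch.Size([2, 10, 512]), same shape, applied position-wise
```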
3.2 The Complete Encoder Layer
```python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout)
        self.ffn = PositionwiseFeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Reshape the (batch, seq_len) padding mask (1 = keep, 0 = pad) so it
        # broadcasts over heads and query positions: (batch, 1, 1, seq_len).
        # If no padding mask is given, fall back to src_mask.
        if src_key_padding_mask is not None:
            mask = src_key_padding_mask.unsqueeze(1).unsqueeze(2)
        else:
            mask = src_mask

        # Self-attention sub-layer
        src2 = self.self_attn(src, src, src, mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Feed-forward sub-layer
        src2 = self.ffn(src)
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
```

Key design choices:
- Residual connection followed by LayerNorm (Post-Norm structure)
- Each sub-layer has its own Dropout
- Padding masks are supported (the (batch, seq_len) mask is reshaped so it broadcasts across heads)
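Many recent implementations use the Pre-Norm arrangement instead, which tends to train more stably for deep stacks. The following is only an illustrative sketch of the alternative forward pass, not part of the model above; if you adopt it, a final LayerNorm is usually added after the whole stack:

```python
# Pre-Norm variant: LayerNorm is applied before each sub-layer instead of after
def forward_pre_norm(self, src, src_mask=None, src_key_padding_mask=None):
    if src_key_padding_mask is not None:
        mask = src_key_padding_mask.unsqueeze(1).unsqueeze(2)
    else:
        mask = src_mask
    x = self.norm1(src)
    src = src + self.dropout1(self.self_attn(x, x, x, mask))
    src = src + self.dropout2(self.ffn(self.norm2(src)))
    return src
```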
4. Positional Encoding and Full Model Integration
4.1 Sinusoidal Positional Encoding
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)                  # (1, max_len, d_model)
        self.register_buffer('pe', pe)        # saved with the model, but not a trainable parameter

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]
```

This implementation:
- Uses sine/cosine functions to generate the positional encodings
- Lowers the frequency as the dimension index increases
- Adds the encoding to the input embeddings
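A quick check that the buffer has the expected shape and that different positions get distinct encodings (illustrative sizes):

```python
pos_enc = PositionalEncoding(d_model=512)
print(pos_enc.pe.shape)                      # torch.Size([1, 5000, 512])
x = torch.zeros(1, 4, 512)
out = pos_enc(x)
print(torch.allclose(out[0, 0], out[0, 1]))  # False, positions are distinguishable
```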
4.2 The Complete Transformer Encoder
```python
class TransformerEncoderModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.encoder = TransformerEncoder(d_model, nhead, num_layers,
                                          dim_feedforward, dropout)
        self.d_model = d_model

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Scale embeddings by sqrt(d_model) before adding positional encodings
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        return self.encoder(src, src_mask, src_key_padding_mask)
```

Usage example:
```python
model = TransformerEncoderModel(vocab_size=10000)
src = torch.randint(0, 10000, (32, 100))  # batch=32, seq_len=100
output = model(src)                       # (32, 100, 512)
```

5. Debugging Tips and Performance Optimization
5.1 Troubleshooting Common Problems
| Symptom | Likely cause | Fix |
|---|---|---|
| NaN loss | Exploding gradients | Check the scaling factor; add gradient clipping |
| Slow training | Ill-chosen head dimension | Make sure d_model is divisible by nhead |
| Overfitting | Insufficient regularization | Tune the dropout rate; add layer normalization |
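For the gradient-clipping suggestion, a typical training-loop fragment looks like this (optimizer, criterion, src, and target are assumed to exist elsewhere):

```python
# Clip the global gradient norm before the optimizer step (max_norm=1.0 is a common default)
loss = criterion(model(src), target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```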
5.2 Performance Optimization Tips
- Memory optimization:
```python
# Use gradient checkpointing to trade compute for memory
from torch.utils.checkpoint import checkpoint
output = checkpoint(layer, src)
```

- Mixed-precision training:
```python
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    output = model(src)
    loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

- Batching optimization:
```python
# Use pad_sequence to batch variable-length inputs
from torch.nn.utils.rnn import pad_sequence
padded = pad_sequence(sequences, batch_first=True)
```
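Building on the padded batch, a key padding mask can be derived and passed to the model from section 4.2. This is a minimal sketch that assumes padding id 0 and the mask convention used above (1 = real token, 0 = padding):

```python
# sequences: list of 1-D LongTensors of token ids; id 0 is assumed to be padding
padded = pad_sequence(sequences, batch_first=True, padding_value=0)
key_padding_mask = (padded != 0).long()    # (batch, seq_len), 1 = keep, 0 = pad
output = model(padded, src_key_padding_mask=key_padding_mask)
```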
6. Practical Application Example: Text Classification

Let's build a text classifier on top of the Encoder we implemented:
```python
class TextClassifier(nn.Module):
    def __init__(self, vocab_size, num_classes, d_model=256, nhead=8, num_layers=4):
        super().__init__()
        self.encoder = TransformerEncoderModel(vocab_size, d_model, nhead, num_layers)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, src):
        encoded = self.encoder(src)        # (batch, seq_len, d_model)
        pooled = encoded.mean(dim=1)       # mean pooling over the sequence
        return self.classifier(pooled)
```

Training tips:
- Use learning-rate warmup (see the schedule below)
- Gradually increase the sequence length
- Monitor the attention patterns
```python
# Learning-rate schedule example: linear warmup, then inverse square-root decay
def lr_schedule(step, d_model=256, warmup=4000):
    step = max(step, 1)                 # avoid division by zero at step 0
    arg1 = step ** -0.5
    arg2 = step * (warmup ** -1.5)
    return (d_model ** -0.5) * min(arg1, arg2)
```
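One way to hook this schedule into a PyTorch optimizer is via LambdaLR. The sketch below assumes a base learning rate of 1.0 so that lr_schedule's output is used directly as the learning rate:

```python
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)

# inside the training loop, call scheduler.step() after each optimizer.step()
```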
7. Advanced Topics: Custom Attention Variants

7.1 Relative-Position Attention
```python
class RelativePositionAttention(nn.Module):
    def __init__(self, d_model, nhead, max_relative_pos=16):
        super().__init__()
        self.d_k = d_model // nhead
        self.nhead = nhead
        self.max_relative_pos = max_relative_pos
        # Learnable embedding for each clipped relative distance
        self.relative_pos_emb = nn.Parameter(
            torch.randn(2 * max_relative_pos - 1, self.d_k))

    def forward(self, q, k, v):
        batch_size, seq_len = q.size(0), q.size(1)
        # Relative position index matrix
        range_vec = torch.arange(seq_len, device=q.device)
        distance_mat = range_vec[:, None] - range_vec[None, :]
        # Clip to [-(max-1), max-1] so the shifted indices land in [0, 2*max-2]
        distance_mat_clipped = torch.clamp(
            distance_mat, -self.max_relative_pos + 1, self.max_relative_pos - 1)
        final_mat = distance_mat_clipped + self.max_relative_pos - 1
        # Look up a (seq_len, seq_len, d_k) table of relative-position embeddings
        relative_pos_emb = self.relative_pos_emb[final_mat]
        # Incorporate into the attention computation
        # ... (the full implementation also needs the interaction between q and the relative positions)
```

7.2 Sparse Attention Patterns
```python
def sparse_attention_mask(seq_len, window_size=32):
    # 1 = position may be attended to, 0 = blocked,
    # matching the masked_fill(mask == 0, -1e9) convention used above
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2)
        mask[i, start:end] = 1
    return mask.bool()
```

These advanced variants can:
- Handle longer sequences
- Capture local dependencies
- Reduce computational complexity
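As a rough sketch of how the local-window mask could be fed through the encoder built above (it is passed as src_mask; the extra leading dimensions let it broadcast over the batch and the heads):

```python
seq_len = 100
src = torch.randint(0, 10000, (4, seq_len))
local_mask = sparse_attention_mask(seq_len, window_size=32)   # (seq_len, seq_len)
local_mask = local_mask.unsqueeze(0).unsqueeze(0)             # (1, 1, seq_len, seq_len)
output = model(src, src_mask=local_mask)                      # model from section 4.2
```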
Implementing a complete Transformer Encoder not only deepens your understanding of the theory, it also lays a solid foundation for building custom architectures. In real projects you can keep extending this implementation, for example by adding Adapter layers, implementing different attention mechanisms, or integrating it with other modules.
