PyTorch 2.0+ 实现 Transformer:6层编码器/解码器在 WMT14 数据集上的完整训练流程
PyTorch 2.0+ 实现 Transformer:6层编码器/解码器在 WMT14 数据集上的完整训练流程
Transformer 架构自 2017 年提出以来,已成为自然语言处理领域的基石模型。本文将深入探讨如何使用 PyTorch 2.0+ 实现一个完整的 Transformer 模型,并在 WMT14 英德翻译数据集上进行训练。不同于简单的玩具实现,我们将重点关注工业级的数据流水线构建、混合精度训练和梯度累积等高级技巧。
1. 环境准备与数据预处理
1.1 安装依赖与配置
首先确保已安装 PyTorch 2.0+ 和必要的依赖库:
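pip install torch torchtext torchdata sacrebleu tensorboard
后文的分词器基于 spaCy,因此还需要安装 spaCy 以及德语、英语模型:
pip install spacy
python -m spacy download de_core_news_sm
python -m spacy download en_core_web_sm
对于 GPU 加速训练,建议使用 CUDA 11.7+ 版本。我们可以通过以下代码检查环境配置: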
import torch

# Report the installed PyTorch version, whether CUDA is usable, and the value
# of torch.version.cuda. A single print call emits one item per line.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    f"CUDA version: {torch.version.cuda}",
    sep="\n",
)

# Next section of the article: 1.2 Loading the WMT14 dataset
WMT14 是机器翻译领域的标准基准数据集,包含约 450 万句英德平行语料。我们将使用 torchtext 提供的 API 进行加载:
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WMT14
from torchtext.vocab import build_vocab_from_iterator

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Special symbols are inserted at the front of every vocabulary in this order,
# so their indices are fixed and identical for both languages.
SPECIAL_SYMBOLS = ['<unk>', '<pad>', '<bos>', '<eos>']
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3

# Load the tokenizers
token_transform = {
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}


# Build the vocabulary
def build_vocab(filepaths, tokenizer, min_freq=2):
    """Build a vocabulary from UTF-8 text files containing one sentence per line.

    Args:
        filepaths: Iterable of paths to the text files to scan.
        tokenizer: Callable mapping a sentence string to a list of tokens.
        min_freq: Minimum number of occurrences for a token to be kept.

    Returns:
        A torchtext vocabulary whose first entries are ``SPECIAL_SYMBOLS`` and
        which maps out-of-vocabulary tokens to ``UNK_IDX``.
    """

    def iter_token_lists():
        for filepath in filepaths:
            with open(filepath, 'r', encoding='utf-8') as f:
                for line in f:
                    # Drop the trailing newline so it does not become a token,
                    # and skip blank lines.
                    line = line.strip()
                    if line:
                        yield tokenizer(line)

    vocab = build_vocab_from_iterator(
        iter_token_lists(),
        min_freq=min_freq,
        specials=SPECIAL_SYMBOLS,
        special_first=True,
    )
    # Without a default index, looking up an unseen token raises an error.
    vocab.set_default_index(UNK_IDX)
    return vocab


# NOTE(review): confirm that the installed torchtext release still ships
# `torchtext.datasets.WMT14` and that the returned object exposes `.src` /
# `.tgt` file paths; `build_vocab` opens every entry of `filepaths` as a file.
train_iter = WMT14(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
vocab_transform = {
    SRC_LANGUAGE: build_vocab([train_iter.src], token_transform[SRC_LANGUAGE]),
    TGT_LANGUAGE: build_vocab([train_iter.tgt], token_transform[TGT_LANGUAGE]),
}

# Tip from the article: in a real project, preprocess the dataset once and save
# it locally so that it is not reprocessed on every training run.
1.3 数据流水线优化
为了高效加载数据,我们实现一个自定义的 Dataset 和 DataLoader:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


class TranslationDataset(Dataset):
    """Parallel-sentence dataset that yields (source ids, target ids) tensors.

    Every sentence is tokenized, mapped through its vocabulary and wrapped in
    the vocabulary's ``<bos>`` / ``<eos>`` ids. The training code shifts the
    target by one position and greedy decoding stops on ``<eos>``, so both
    markers must be present in the data.

    Args:
        src_sentences: Sequence of source-language sentence strings.
        tgt_sentences: Sequence of target-language sentence strings, aligned
            with ``src_sentences``.
        src_vocab: Mapping from source token to id (supports ``vocab[token]``).
        tgt_vocab: Mapping from target token to id (supports ``vocab[token]``).
        src_tokenizer: Optional callable ``str -> list[str]``; defaults to
            ``token_transform[SRC_LANGUAGE]``.
        tgt_tokenizer: Optional callable ``str -> list[str]``; defaults to
            ``token_transform[TGT_LANGUAGE]``.

    Raises:
        ValueError: If the two sentence lists differ in length.
    """

    def __init__(self, src_sentences, tgt_sentences, src_vocab, tgt_vocab,
                 src_tokenizer=None, tgt_tokenizer=None):
        if len(src_sentences) != len(tgt_sentences):
            raise ValueError(
                "src_sentences and tgt_sentences must have the same length"
            )
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.src_tokenizer = (
            token_transform[SRC_LANGUAGE] if src_tokenizer is None else src_tokenizer
        )
        self.tgt_tokenizer = (
            token_transform[TGT_LANGUAGE] if tgt_tokenizer is None else tgt_tokenizer
        )

    def __len__(self):
        return len(self.src_sentences)

    @staticmethod
    def _encode(sentence, tokenizer, vocab):
        """Return ``<bos> tokens... <eos>`` as a 1-D int64 tensor of ids."""
        ids = [vocab['<bos>']]
        ids.extend(vocab[token] for token in tokenizer(sentence))
        ids.append(vocab['<eos>'])
        # An explicit dtype keeps the tensor integral for the embedding lookup.
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, idx):
        src_tensor = self._encode(
            self.src_sentences[idx], self.src_tokenizer, self.src_vocab
        )
        tgt_tensor = self._encode(
            self.tgt_sentences[idx], self.tgt_tokenizer, self.tgt_vocab
        )
        return src_tensor, tgt_tensor


def collate_fn(batch):
    """Pad a list of (src, tgt) id tensors into two (max_len, batch) tensors.

    NOTE(review): relies on a module-level ``PAD_IDX``; it should be the index
    of '<pad>' in the vocabularies -- confirm where it is defined.
    """
    src_batch, tgt_batch = zip(*batch)
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch


# NOTE(review): `train_dataset` is not created in this snippet; it is expected
# to be a TranslationDataset instance built beforehand.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)

# Next section of the article: 2. Transformer model implementation
2.1 模型参数配置
我们首先定义 Transformer 的核心参数:
class TransformerConfig:
    """Hyper-parameters of the Transformer translation model.

    Called without arguments it reproduces the article's defaults and reads
    the vocabulary sizes from the module-level ``vocab_transform``.

    Args:
        src_vocab_size: Source vocabulary size; defaults to
            ``len(vocab_transform[SRC_LANGUAGE])``.
        tgt_vocab_size: Target vocabulary size; defaults to
            ``len(vocab_transform[TGT_LANGUAGE])``.
        **overrides: Replacement values for any existing attribute, for
            example ``d_model=256``.

    Raises:
        TypeError: If an override names an attribute that does not exist.
    """

    def __init__(self, src_vocab_size=None, tgt_vocab_size=None, **overrides):
        self.d_model = 512              # embedding dimension
        self.nhead = 8                  # number of attention heads
        self.num_encoder_layers = 6     # number of encoder layers
        self.num_decoder_layers = 6     # number of decoder layers
        self.dim_feedforward = 2048     # feed-forward network dimension
        self.dropout = 0.1              # dropout probability
        self.activation = 'relu'        # activation function
        self.max_seq_length = 100       # maximum sequence length
        # The global vocabularies are only consulted when no size is given, so
        # the config can be built before (or without) loading the dataset.
        self.src_vocab_size = (
            len(vocab_transform[SRC_LANGUAGE]) if src_vocab_size is None else src_vocab_size
        )
        self.tgt_vocab_size = (
            len(vocab_transform[TGT_LANGUAGE]) if tgt_vocab_size is None else tgt_vocab_size
        )

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"unknown TransformerConfig option: {name!r}")
            setattr(self, name, value)

# Next section of the article: 2.2 Positional encoding implementation
位置编码是 Transformer 的关键组件,用于注入序列的位置信息:
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to an embedded sequence.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Largest sequence length for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # the frequency table is truncated to match.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer moves with the module across devices but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:len])``; dim 0 of ``x`` is the position.

        Raises:
            ValueError: If ``x`` is longer than ``max_len`` positions.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        x = x + self.pe[:seq_len]
        return self.dropout(x)

# Next section of the article: 2.3 Full Transformer implementation
基于 PyTorch 的 nn.Transformer 模块,我们可以构建完整的模型:
import torch
from torch import nn


class TransformerModel(nn.Module):
    """Encoder-decoder translation model built on ``nn.Transformer``.

    Token tensors are sequence-first, ``(seq_len, batch)``: ``nn.Transformer``
    is created with its default layout and the positional encoding indexes
    positions along dim 0.

    Args:
        config: Object exposing ``src_vocab_size``, ``tgt_vocab_size``,
            ``d_model``, ``nhead``, ``num_encoder_layers``,
            ``num_decoder_layers``, ``dim_feedforward``, ``dropout`` and
            ``activation``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.src_tok_emb = nn.Embedding(config.src_vocab_size, config.d_model)
        self.tgt_tok_emb = nn.Embedding(config.tgt_vocab_size, config.d_model)
        self.positional_encoding = PositionalEncoding(config.d_model, config.dropout)

        self.transformer = nn.Transformer(
            d_model=config.d_model,
            nhead=config.nhead,
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation=config.activation,
        )

        self.fc_out = nn.Linear(config.d_model, config.tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Return target-vocabulary logits of shape ``(tgt_len, batch, vocab)``.

        When ``tgt_mask`` is omitted a causal mask is built, so position ``t``
        of the decoder attends only to target positions ``<= t``. Without it
        the decoder could read the tokens it is trained to predict. Padding
        masks are not derived here; callers pass them when needed.
        """
        if tgt_mask is None:
            tgt_len = tgt.size(0)
            # -inf strictly above the diagonal blocks attention to the future.
            tgt_mask = torch.triu(
                torch.full((tgt_len, tgt_len), float('-inf'), device=tgt.device),
                diagonal=1,
            )

        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))

        output = self.transformer(
            src_emb, tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

        return self.fc_out(output)

# Next section of the article: 3. Training pipeline optimization
3.1 混合精度训练
PyTorch 的 AMP (Automatic Mixed Precision) 可以显著加速训练并减少显存占用:
import torch
from torch.cuda.amp import GradScaler, autocast

# Mixed precision is only active on CUDA. Disabling both helpers explicitly on
# CPU keeps them as silent no-ops.
_AMP_ENABLED = torch.cuda.is_available()
scaler = GradScaler(enabled=_AMP_ENABLED)


def train_step(model, optimizer, criterion, src, tgt):
    """Run one optimisation step on a single batch and return its loss.

    Args:
        model: Callable as ``model(src, tgt_input, tgt_mask=...)`` returning
            logits whose last dimension is the vocabulary.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss taking ``(N, vocab)`` logits and ``(N,)`` targets.
        src: Source token ids, sequence-first ``(src_len, batch)``.
        tgt: Target token ids, sequence-first ``(tgt_len, batch)``.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Teacher forcing: the decoder reads tokens [0, T-1) and predicts [1, T).
    tgt_input = tgt[:-1, :]
    tgt_output = tgt[1:, :]

    # Causal mask: -inf above the diagonal stops position t from attending to
    # later target tokens, which are the ones it has to predict.
    steps = tgt_input.size(0)
    tgt_mask = torch.triu(
        torch.full((steps, steps), float('-inf'), device=tgt.device),
        diagonal=1,
    )

    with autocast(enabled=_AMP_ENABLED):
        output = model(src, tgt_input, tgt_mask=tgt_mask)
        # reshape (unlike view) also accepts non-contiguous tensors.
        loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item()

# Next section of the article: 3.2 Gradient accumulation
对于大 batch size 训练,可以使用梯度累积技术:
import torch
from torch.cuda.amp import GradScaler, autocast

accumulation_steps = 4


def train_epoch(model, optimizer, criterion, train_loader, device=None, grad_scaler=None):
    """Train for one epoch with gradient accumulation and return the mean loss.

    Gradients of ``accumulation_steps`` consecutive batches are summed before
    each optimizer step. A trailing group smaller than ``accumulation_steps``
    is also applied, so no batch is silently discarded.

    Args:
        model: Callable as ``model(src, tgt_input, tgt_mask=...)``.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss taking ``(N, vocab)`` logits and ``(N,)`` targets.
        train_loader: Iterable of ``(src, tgt)`` sequence-first id tensors.
        device: Device the batches are moved to; defaults to the device of the
            model's first parameter.
        grad_scaler: ``GradScaler`` to use; defaults to the module-level
            ``scaler``.

    Returns:
        Mean unscaled loss per batch, or ``0.0`` if the loader is empty.
    """
    if device is None:
        device = next(model.parameters()).device
    if grad_scaler is None:
        grad_scaler = scaler
    amp_enabled = torch.cuda.is_available()

    model.train()
    total_loss = 0.0
    num_batches = 0
    pending = False  # True while accumulated gradients await a step
    optimizer.zero_grad()

    for i, (src, tgt) in enumerate(train_loader):
        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing: feed tokens [0, T-1), predict tokens [1, T).
        tgt_input = tgt[:-1, :]
        tgt_output = tgt[1:, :]

        # Causal mask so the decoder cannot look at the tokens it predicts.
        steps = tgt_input.size(0)
        tgt_mask = torch.triu(
            torch.full((steps, steps), float('-inf'), device=tgt.device),
            diagonal=1,
        )

        with autocast(enabled=amp_enabled):
            output = model(src, tgt_input, tgt_mask=tgt_mask)
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
            # Scale down so the accumulated gradient is an average.
            loss = loss / accumulation_steps

        grad_scaler.scale(loss).backward()
        pending = True

        if (i + 1) % accumulation_steps == 0:
            grad_scaler.step(optimizer)
            grad_scaler.update()
            optimizer.zero_grad()
            pending = False

        # Undo the scaling so the reported value is the true batch loss.
        total_loss += loss.item() * accumulation_steps
        num_batches += 1

    if pending:
        # Apply the gradients of the last, incomplete accumulation group.
        grad_scaler.step(optimizer)
        grad_scaler.update()
        optimizer.zero_grad()

    return total_loss / num_batches if num_batches else 0.0

# Next section of the article: 3.3 Learning-rate scheduling
使用余弦退火学习率调度器:
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Number of epochs the cosine schedule spans. It matches the 30-epoch training
# loop: with a shorter T_max the learning rate would start rising again once
# T_max epochs have passed.
NUM_EPOCHS = 30

# NOTE(review): `model` must already exist when this snippet runs; it is not
# created here.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.98), eps=1e-9)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-5)

# Next section of the article: 4. Evaluation and result analysis
4.1 BLEU 分数评估
使用 sacreBLEU 进行自动评估:
import torch


def evaluate(model, val_loader, max_len=100):
    """Translate the validation set greedily and return the corpus BLEU score.

    Every sentence in a batch is decoded on its own, converted back to text
    with the target vocabulary, and compared with its reference by sacreBLEU.

    Args:
        model: Model accepted by ``greedy_decode``.
        val_loader: Iterable of ``(src, tgt)`` sequence-first id tensors.
        max_len: Maximum length of each generated translation.

    Returns:
        The BLEU score as a float, or ``0.0`` for an empty loader.

    NOTE(review): relies on the module-level ``vocab_transform``,
    ``TGT_LANGUAGE``, ``PAD_IDX``, ``BOS_IDX`` and ``EOS_IDX``. Text is rebuilt
    by joining tokens with spaces, which only approximates detokenization.
    """
    # Imported here so greedy_decode stays usable without sacreBLEU installed.
    from sacrebleu import corpus_bleu

    tgt_vocab = vocab_transform[TGT_LANGUAGE]
    special_ids = {PAD_IDX, BOS_IDX, EOS_IDX}

    def ids_to_text(ids):
        kept = [i for i in ids if i not in special_ids]
        return " ".join(tgt_vocab.lookup_tokens(kept))

    device = next(model.parameters()).device
    model.eval()
    translations = []
    references = []

    with torch.no_grad():
        for src, tgt in val_loader:
            src = src.to(device)
            tgt = tgt.to(device)

            for col in range(src.size(1)):
                # Drop padding so the encoder sees only real tokens.
                sentence = src[:, col]
                sentence = sentence[sentence != PAD_IDX].unsqueeze(1)

                # Generate the translation with greedy decoding
                hypothesis = greedy_decode(model, sentence, max_len)
                translations.append(ids_to_text(hypothesis.view(-1).tolist()))
                references.append(ids_to_text(tgt[:, col].tolist()))

    if not translations:
        return 0.0

    # sacreBLEU takes hypothesis strings and a list of reference streams.
    bleu_score = corpus_bleu(translations, [references])
    return bleu_score.score


def greedy_decode(model, src, max_len, bos_idx=None, eos_idx=None):
    """Greedily decode one source sentence.

    Args:
        model: Object exposing ``src_tok_emb``, ``tgt_tok_emb``,
            ``positional_encoding``, ``transformer`` and ``fc_out``.
        src: Source ids of shape ``(src_len,)`` or ``(src_len, 1)``.
        max_len: Maximum number of output tokens, including ``<bos>``.
        bos_idx: Start-of-sentence id; defaults to the global ``BOS_IDX``.
        eos_idx: End-of-sentence id; defaults to the global ``EOS_IDX``.

    Returns:
        Tensor of shape ``(1, out_len)`` that starts with ``bos_idx`` and ends
        with ``eos_idx`` if it was produced within ``max_len`` tokens.

    Raises:
        ValueError: If ``src`` holds more than one sentence.
    """
    bos_idx = BOS_IDX if bos_idx is None else bos_idx
    eos_idx = EOS_IDX if eos_idx is None else eos_idx

    if src.dim() == 1:
        src = src.unsqueeze(1)
    if src.size(1) != 1:
        raise ValueError("greedy_decode expects a single sentence of shape (src_len, 1)")

    with torch.no_grad():
        memory = model.transformer.encoder(
            model.positional_encoding(model.src_tok_emb(src))
        )
        ys = torch.full((1, 1), bos_idx, dtype=src.dtype, device=src.device)

        for _ in range(max_len - 1):
            # The model is sequence-first, so the (1, L) prefix becomes (L, 1).
            prefix = ys.transpose(0, 1)
            steps = prefix.size(0)
            tgt_mask = torch.triu(
                torch.full((steps, steps), float('-inf'), device=src.device),
                diagonal=1,
            )
            out = model.transformer.decoder(
                model.positional_encoding(model.tgt_tok_emb(prefix)),
                memory,
                tgt_mask=tgt_mask,
            )
            # Only the last position predicts the next token.
            logits = model.fc_out(out[-1])
            next_word = int(logits.argmax(dim=-1).item())

            ys = torch.cat(
                [ys, torch.full((1, 1), next_word, dtype=src.dtype, device=src.device)],
                dim=1,
            )
            if next_word == eos_idx:
                break

    return ys

# Next section of the article: 4.2 Training monitoring
使用 TensorBoard 记录训练过程:
from torch.utils.tensorboard import SummaryWriter

# Log the per-epoch training loss and validation BLEU to TensorBoard.
writer = SummaryWriter()

try:
    for epoch in range(30):
        train_loss = train_epoch(model, optimizer, criterion, train_loader)
        val_bleu = evaluate(model, val_loader)

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('BLEU/val', val_bleu, epoch)

        # The scheduler is advanced once per epoch.
        scheduler.step()

        print(f'Epoch: {epoch+1:02d} | Train Loss: {train_loss:.3f} | Val BLEU: {val_bleu:.2f}')
finally:
    # Close the writer even if training is interrupted, so that pending
    # events are flushed to disk.
    writer.close()

# Next section of the article: 5. Advanced optimization techniques
5.1 标签平滑
标签平滑可以防止模型对训练数据过度自信:
import torch
import torch.nn.functional as F
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    The true class receives probability ``1 - smoothing``. The remaining mass
    is spread evenly over the other classes except the padding class. Rows
    whose target is the padding index contribute nothing to the loss.

    Args:
        classes: Size of the output vocabulary.
        padding_idx: Index of the padding token.
        smoothing: Probability mass moved away from the true class.
    """

    def __init__(self, classes, padding_idx, smoothing=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.classes = classes
        self.true_dist = None

    def forward(self, x, target):
        """Return the summed KL divergence for one batch.

        Args:
            x: Scores of shape ``(N, classes)``; raw logits or log-probabilities.
            target: Class indices of shape ``(N,)``.

        Raises:
            ValueError: If ``x`` does not have ``classes`` columns.
        """
        if x.size(1) != self.classes:
            raise ValueError(
                f"expected {self.classes} classes, got input with {x.size(1)}"
            )

        # KLDivLoss expects log-probabilities, while the model emits logits.
        # log_softmax leaves already-normalised log-probabilities unchanged.
        log_probs = F.log_softmax(x, dim=-1)

        with torch.no_grad():
            # "- 2" excludes the true class and the padding class.
            true_dist = torch.full_like(log_probs, self.smoothing / (self.classes - 2))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            true_dist[:, self.padding_idx] = 0
            # Zero whole rows whose target is padding.
            true_dist[target == self.padding_idx] = 0

        self.true_dist = true_dist
        return self.criterion(log_probs, true_dist)


criterion = LabelSmoothingLoss(len(vocab_transform[TGT_LANGUAGE]), PAD_IDX)

# Next section of the article: 5.2 Model parallelism and data parallelism
对于大型 Transformer 模型,可以使用分布式训练:
import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP def setup(rank, world_size): dist.init_process_group("nccl", rank=rank, world_size=world_size) def cleanup(): dist.destroy_process_group() def train_distributed(rank, world_size): setup(rank, world_size) model = TransformerModel(config).to(rank) model = DDP(model, device_ids=[rank]) optimizer = torch.optim.Adam(model.parameters()) # 训练循环... cleanup()5.3 模型量化与优化
训练完成后,可以对模型进行量化以提升推理速度:
# Dynamically quantize every nn.Linear layer of the trained model to int8.
# NOTE(review): `model` is expected to be the trained network defined earlier.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

# Next section of the article: 6. Practical deployment advice
在生产环境中部署 Transformer 模型时,建议考虑以下优化:
- 使用 TorchScript 导出模型:
# Compile the model to TorchScript, then serialise it to a file in the
# current working directory.
scripted_model = torch.jit.script(model)
scripted_model.save("transformer_scripted.pt")

# Next item in the article's list: implement efficient batched inference
import asyncio
from concurrent.futures import ThreadPoolExecutor

import torch
from torch.utils.data import DataLoader


class InferencePipeline:
    """Batched inference over a TorchScript model using a thread pool.

    Args:
        model_path: Path of a module saved with TorchScript.
        batch_size: Number of inputs per forward pass; must be at least 1.
        max_workers: Number of worker threads running forward passes.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, model_path, batch_size=32, max_workers=4):
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.model = torch.jit.load(model_path)
        # Inference must not apply dropout.
        self.model.eval()
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.batch_size = batch_size

    def process_batch(self, batch):
        """Run the model on one batch without tracking gradients."""
        with torch.no_grad():
            return self.model(batch)

    async def predict(self, inputs):
        """Run the model over ``inputs`` in chunks and concatenate the results.

        Each chunk is submitted to the thread pool and awaited, so the event
        loop stays responsive while the model runs. Results are concatenated
        in input order. An empty input yields an empty tensor.
        """
        batches = [
            inputs[start:start + self.batch_size]
            for start in range(0, len(inputs), self.batch_size)
        ]
        if not batches:
            return torch.empty(0)

        loop = asyncio.get_running_loop()
        # gather returns results in submission order.
        results = await asyncio.gather(
            *(loop.run_in_executor(self.executor, self.process_batch, batch)
              for batch in batches)
        )
        return torch.cat(results)

    def close(self):
        """Shut down the worker threads once pending batches have finished."""
        self.executor.shutdown(wait=True)