Stop Rote-Memorizing the BERT Architecture! Hand-Build a BERT-Base in PyTorch and Truly Understand MLM and NSP
Implementing BERT-Base from Scratch: A Hands-On PyTorch Guide to MLM and NSP
1. Why Implement BERT by Hand?
BERT has become a foundational architecture in natural language processing. Yet many developers find that using BERT only through the transformers library is like driving a sports car whose hood you cannot open: you can press the accelerator, but you have no idea what is happening inside.
The core value of a from-scratch implementation is understanding:
- How the cleverly designed 80-10-10 masking strategy addresses the data-distribution mismatch between pretraining and fine-tuning
- What summing the three embeddings means mathematically, and how it encodes position information
- How the attention heads carve up a single set of Q/K/V projection matrices, and how this design affects model performance
- Where layer normalization is placed relative to the residual connections, and why that choice matters for training stability
The first time I tried to change the size of BERT's attention heads, I finally appreciated the engineering wisdom behind architectural decisions that look trivial on paper. Let's build a complete, trainable BERT-Base model from scratch in PyTorch.
2. Model Architecture Design
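All of the modules below read their hyperparameters from a shared config object. Here is a minimal sketch of such a config; the field names are assumptions chosen to match the code in this post, and the default values are the published BERT-Base sizes:

```python
from dataclasses import dataclass


@dataclass
class BERTConfig:
    vocab_size: int = 30522              # WordPiece vocabulary size used by BERT-Base
    hidden_size: int = 768               # width of the token representations
    num_hidden_layers: int = 12          # number of stacked encoder layers
    num_attention_heads: int = 12        # 768 / 12 = 64 dimensions per head
    intermediate_size: int = 3072        # width of the feed-forward sub-layer
    max_position_embeddings: int = 512   # maximum sequence length
    type_vocab_size: int = 2             # sentence A / sentence B segments
    hidden_dropout_prob: float = 0.1
    attention_probs_dropout_prob: float = 0.1


config = BERTConfig()
```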
2.1 The Embedding Layer
BERT's embedding layer consists of three parts, which can be expressed mathematically as:
$$ \text{Embedding} = \text{TokenEmbedding} + \text{SegmentEmbedding} + \text{PositionEmbedding} $$
```python
import torch
import torch.nn as nn


class BERTEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.segment_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        token_emb = self.token_embeddings(input_ids)
        position_emb = self.position_embeddings(position_ids)  # broadcasts over the batch dimension
        segment_emb = self.segment_embeddings(token_type_ids)
        embeddings = token_emb + position_emb + segment_emb
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
```

Key detail: the position embeddings are learnable parameters rather than fixed sinusoidal functions, an important difference between BERT and the original Transformer.
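As a quick sanity check, here is what calling the embedding module on a toy batch might look like (shapes assume the BERT-Base config sketched above):

```python
embeddings = BERTEmbeddings(config)
input_ids = torch.randint(0, config.vocab_size, (2, 16))  # batch of 2 sequences, 16 tokens each
out = embeddings(input_ids)
print(out.shape)  # torch.Size([2, 16, 768])
```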
2.2 The Transformer Encoder Layer
Each encoder layer contains three components (a sketch of how they fit together follows this list):
- Multi-head self-attention
- A position-wise feed-forward network
- Residual connections and layer normalization
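Here is a minimal sketch of how these three pieces could be wired together, assuming the `BERTSelfAttention` module implemented right after and the `intermediate_size` field from the config above; the layer names are my own, not taken from the original code:

```python
class BERTEncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each wrapped in residual + LayerNorm (post-LN)."""
    def __init__(self, config):
        super().__init__()
        self.attention = BERTSelfAttention(config)  # implemented below
        self.attention_norm = nn.LayerNorm(config.hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )
        self.ffn_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask=None):
        # Sub-layer 1: self-attention, then add & norm
        attn_output = self.attention(hidden_states, attention_mask)
        hidden_states = self.attention_norm(hidden_states + self.dropout(attn_output))
        # Sub-layer 2: feed-forward, then add & norm
        ffn_output = self.ffn(hidden_states)
        hidden_states = self.ffn_norm(hidden_states + self.dropout(ffn_output))
        return hidden_states
```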
```python
import math


class BERTSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states, attention_mask=None):
        batch_size = hidden_states.size(0)
        # Linear projections
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
        # Split into heads: [batch, heads, seq, head_dim]
        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention scores
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask
        # Attention probabilities
        probs = nn.Softmax(dim=-1)(scores)
        probs = self.dropout(probs)
        # Weighted sum over the values
        context = torch.matmul(probs, v)
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.num_heads * self.head_dim)
        # Output projection
        output = self.dense(context)
        return output
```

3. Implementing the Pretraining Tasks
3.1 Masked Language Modeling (MLM)
BERT's MLM task uses the distinctive 80-10-10 strategy. Of the 15% of tokens selected for prediction, each is handled as follows:
| Treatment | Proportion | Example (original sentence: "the man ate an apple") |
|---|---|---|
| Replace with [MASK] | 80% | "the man [MASK] an apple" |
| Replace with a random word | 10% | "the man ran an apple" |
| Keep the original word | 10% | "the man ate an apple" |
```python
import random


def create_masked_lm_predictions(tokens, mask_prob, vocab):
    """Generate an MLM training example from a token sequence.

    `vocab` is the list of vocabulary tokens, used for the 10% random-replacement case.
    """
    output_tokens = list(tokens)
    masked_lm_positions = []
    masked_lm_labels = []
    for i, token in enumerate(tokens):
        if token in ["[CLS]", "[SEP]"]:
            continue
        prob = random.random()
        if prob < mask_prob:
            masked_lm_positions.append(i)
            mask_decision = random.random()
            if mask_decision < 0.8:
                output_tokens[i] = "[MASK]"
            elif mask_decision < 0.9:
                output_tokens[i] = random.choice(vocab)
            # the remaining 10% keep the original token
            masked_lm_labels.append(token)
    return output_tokens, masked_lm_positions, masked_lm_labels
```

3.2 Next Sentence Prediction (NSP)
The rules for constructing NSP training samples:
```python
def create_next_sentence_predictions(text_a, text_b, max_seq_length):
    """Generate an NSP training example."""
    # With 50% probability keep the true next sentence
    if random.random() < 0.5:
        is_next = True
        tokens_a = tokenize(text_a)
        tokens_b = tokenize(text_b)
    else:
        is_next = False
        tokens_a = tokenize(text_a)
        tokens_b = tokenize(random.choice(corpus))  # pick an unrelated sentence at random
    # Truncate so the pair plus special tokens fits within max_seq_length
    truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    # Add the special tokens
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids, is_next
```

4. Assembling the Full Model
Finally, we combine the components into the complete pretraining model:
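The combined model below references two prediction heads, MaskedLMHead and NextSentencePredictionHead, that are not spelled out elsewhere in this post. Here is a minimal sketch of what they could look like; the internal layer names are my own assumptions, with the vocabulary projection named `dense` so that the weight-sharing snippet in section 7.1 still applies:

```python
class MaskedLMHead(nn.Module):
    """Maps masked-token representations back onto the vocabulary."""
    def __init__(self, config):
        super().__init__()
        self.transform = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.GELU()
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        # hidden -> vocab projection; its weight can be tied to the token embedding matrix
        self.dense = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, hidden_states):
        hidden_states = self.LayerNorm(self.activation(self.transform(hidden_states)))
        return self.dense(hidden_states)


class NextSentencePredictionHead(nn.Module):
    """Binary is-next / not-next classifier over the pooled [CLS] representation."""
    def __init__(self, config):
        super().__init__()
        self.classifier = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.classifier(pooled_output)
```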
```python
class BERTForPretraining(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.bert = BERTModel(config)
        self.mlm_head = MaskedLMHead(config)
        self.nsp_head = NextSentencePredictionHead(config)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                masked_lm_positions=None):
        # Run the BERT encoder
        sequence_output, pooled_output = self.bert(
            input_ids, token_type_ids, attention_mask)
        # MLM task: gather the representations at the masked positions
        if masked_lm_positions is not None:
            masked_lm_output = torch.gather(
                sequence_output, 1,
                masked_lm_positions.unsqueeze(-1).expand(-1, -1, sequence_output.size(-1)))
            mlm_scores = self.mlm_head(masked_lm_output)
        else:
            mlm_scores = None
        # NSP task: classify the pooled [CLS] representation
        nsp_scores = self.nsp_head(pooled_output)
        return mlm_scores, nsp_scores
```

5. Training Techniques and Optimizations
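Before the individual tricks, here is a hedged sketch of how a single pretraining step might compute the joint MLM + NSP loss. The tensor names (`masked_lm_labels`, `is_next_labels`) and the plain sum of the two losses are assumptions for illustration, not taken from the original post:

```python
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)  # -100 marks positions that carry no label

mlm_scores, nsp_scores = model(input_ids,
                               token_type_ids=token_type_ids,
                               attention_mask=attention_mask,
                               masked_lm_positions=masked_lm_positions)

# MLM loss over the gathered masked positions only
mlm_loss = loss_fct(mlm_scores.view(-1, config.vocab_size), masked_lm_labels.view(-1))
# NSP loss: binary classification on the pooled [CLS] output
nsp_loss = loss_fct(nsp_scores.view(-1, 2), is_next_labels.view(-1))

loss = mlm_loss + nsp_loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```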
5.1 Dynamic Masking
The original BERT generated its masks once during data preprocessing. A more effective approach is to generate masks dynamically at training time, so each epoch sees a different masking of the same sequence:
```python
class DynamicMasking:
    def __init__(self, mask_prob=0.15):
        self.mask_prob = mask_prob

    def apply(self, batch):
        # `tokenizer` is assumed to be in scope (e.g. a WordPiece tokenizer exposing mask_token_id / vocab_size)
        masked_batch = batch.clone()
        labels = torch.full_like(batch, -100)  # non-masked positions are ignored by the loss
        # Sample random mask positions for every sequence in the batch
        rand = torch.rand(batch.shape, device=batch.device)
        mask_pos = (rand < self.mask_prob) & (batch != 0)  # skip padding
        # 80-10-10 strategy
        mask_decision = torch.rand(batch.shape, device=batch.device)
        masked_batch[mask_pos & (mask_decision < 0.8)] = tokenizer.mask_token_id
        random_words = torch.randint(0, tokenizer.vocab_size, batch.shape, device=batch.device)
        random_pos = mask_pos & (mask_decision >= 0.8) & (mask_decision < 0.9)
        masked_batch[random_pos] = random_words[random_pos]
        labels[mask_pos] = batch[mask_pos]
        return masked_batch, labels
```

5.2 Gradient Accumulation
When GPU memory is limited, gradient accumulation can simulate a larger batch size:
```python
accumulation_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    loss = model(batch).mean()
    loss = loss / accumulation_steps  # scale so the accumulated gradient matches a large-batch update
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

6. Performance Optimization Techniques
6.1 Mixed-Precision Training
Use AMP (Automatic Mixed Precision) to speed up training:
```python
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    # Run the forward pass in mixed precision
    with torch.cuda.amp.autocast():
        loss = model(batch)
    # Scale the loss before backprop, then unscale and step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

6.2 Attention Optimization
Implementing a memory-efficient attention computation:
```python
def memory_efficient_attention(q, k, v, mask=None):
    """Compute attention in query chunks so the full score matrix is never materialized at once."""
    chunk_size = 64  # tune to the available GPU memory
    scale = math.sqrt(q.size(-1))
    output = torch.zeros_like(q)
    for i in range(0, q.size(2), chunk_size):
        q_chunk = q[:, :, i:i + chunk_size]                           # [batch, heads, chunk, head_dim]
        scores = torch.einsum('bhid,bhjd->bhij', q_chunk, k) / scale  # [batch, heads, chunk, seq]
        if mask is not None:
            scores = scores + mask  # a [batch, 1, 1, seq] additive mask broadcasts over the chunk
        probs = torch.softmax(scores, dim=-1)
        # Weighted sum against the full value tensor for this block of queries
        output[:, :, i:i + chunk_size] = torch.einsum('bhij,bhjd->bhid', probs, v)
    return output
```

7. Deployment in Practice
7.1 Weight Sharing
```python
# Share weights at initialization: the MLM output projection reuses the token embedding matrix,
# which works because both have shape (vocab_size, hidden_size)
self.mlm_head.dense.weight = self.bert.embeddings.token_embeddings.weight
```

7.2 Exporting to ONNX
Export the model to ONNX format for production deployment:
```python
# Example inputs that define the traced shapes; dynamic_axes lets batch size and sequence length vary at runtime
dummy_input_ids = torch.ones(1, 128, dtype=torch.long)
dummy_attention_mask = torch.ones(1, 128, dtype=torch.long)

torch.onnx.export(
    model,
    (dummy_input_ids, dummy_attention_mask),  # positional order must match the model's forward signature
    "bert.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "output": {0: "batch"},
    },
)
```
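A quick way to sanity-check the exported file is to run it with ONNX Runtime. This is a minimal sketch assuming the onnxruntime package is installed and the export above succeeded:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("bert.onnx", providers=["CPUExecutionProvider"])

# Feed numpy arrays whose names match the input_names used at export time
inputs = {
    "input_ids": np.ones((1, 128), dtype=np.int64),
    "attention_mask": np.ones((1, 128), dtype=np.int64),
}
outputs = session.run(None, inputs)  # None = return all outputs
print([o.shape for o in outputs])
```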