动手学深度学习——BERT代码
1. 前言
上一篇我们已经从整体上理解了BERT:
它是基于 Transformer Encoder 的双向预训练语言模型
它采用“预训练 + 微调”的范式
它的核心预训练任务包括:
MLM(Masked Language Modeling)
NSP(Next Sentence Prediction)
它特别适合各种自然语言理解任务
这一篇就继续按李沐的思路,把 BERT 真正落到代码层面。
这一节最重要的,不是一下子把整个预训练过程全写完,
而是先看清楚 BERT 这个模型本身到底由哪些部分组成:
输入表示怎么构造
BERT 编码器层怎么堆叠
MLM 预测头怎么接
NSP 分类头怎么接
前向传播输出到底有哪些东西
如果一句话概括这一节代码的核心,那就是:
BERT = 输入嵌入层 + 多层 Transformer 编码器 + 预训练任务头
2. BERT 代码要解决什么问题
如果从实现角度看,BERT 这节代码主要解决三件事:
第一,构造输入表示
BERT 输入不是单纯 token embedding,
而是:
token embedding
segment embedding
position embedding
三者相加。
第二,搭建深层 Transformer Encoder
让输入序列经过多层自注意力和前馈网络,得到上下文化表示。
第三,接上预训练任务输出头
包括:
MLM 头:预测被 mask 的 token
NSP 头:判断两句是否连续
所以这一节并不是直接训练,
而是在把:
BERT 的骨架和输出接口
先搭完整。
3. BERT 的输入为什么比普通模型复杂
前面我们学 RNN、GRU、LSTM、Seq2Seq 时,
输入往往主要是:
token 索引
或者 token embedding
但 BERT 不一样。
因为它不仅要处理单句,还常常要处理句对任务,
并且 Transformer 本身没有顺序递推结构,所以还必须显式加入位置信息。
因此,BERT 输入通常由三部分组成:
3.1 Token Embedding
表示词本身是谁。
3.2 Segment Embedding
表示这个 token 属于句子 A 还是句子 B。
3.3 Position Embedding
表示这个 token 在序列中的位置。
所以 BERT 输入层本质上比前面的模型更“结构化”。
4. BERT 输入嵌入层通常怎么写
李沐这里常见的 BERT 编码器初始化,大致会写成这样:
class BERTEncoder(nn.Module): def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=1000, key_size=768, query_size=768, value_size=768, **kwargs): super(BERTEncoder, self).__init__(**kwargs) self.token_embedding = nn.Embedding(vocab_size, num_hiddens) self.segment_embedding = nn.Embedding(2, num_hiddens) self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module( f"{i}", d2l.EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, True) ) self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))这段代码基本就把 BERT 编码器的主体结构全展示出来了。
5. 为什么token_embedding很好理解
这一句:
self.token_embedding = nn.Embedding(vocab_size, num_hiddens)和前面很多 NLP 模型一样,
作用就是把 token 索引变成稠密向量。
例如:
deeplearning[CLS][SEP]
这些 token 本身是整数编号,
经过 embedding 以后,变成:
num_hiddens维的向量表示。
这是所有输入表示的基础。
6. 为什么segment_embedding只需要 2
这一句:
self.segment_embedding = nn.Embedding(2, num_hiddens)中的2表示:
只区分两种 segment
通常是:
句子 A
句子 B
例如在句对输入中:
[CLS] 句子A [SEP] 句子B [SEP]那么:
句子 A 对应 segment id = 0
句子 B 对应 segment id = 1
所以 segment embedding 只需要 2 种类型就够了。
它的作用是告诉模型:
当前 token 属于哪一段句子
这对句对任务很重要。
7. 为什么 BERT 的位置编码这里是可学习参数
这一句:
self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))表示:
BERT 这里使用的是可学习的位置编码
也就是说,位置向量不是固定公式生成的,
而是直接作为模型参数参与训练。
这和最原始 Transformer 论文里常见的正弦位置编码不同。
BERT 采用可学习位置嵌入,也是很经典的做法。
它的含义很简单:
序列第 1 个位置有一个向量
第 2 个位置有另一个向量
…
这些位置向量会在训练中自动优化
8. 为什么 BERT 需要多层EncoderBlock
这一段:
self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module(... d2l.EncoderBlock(...))说明 BERT 编码器并不是一层 Transformer,
而是:
多层 Transformer Encoder Block 堆叠
这和前面讲“深层循环神经网络”很像,
只不过这里堆叠的不是 RNN / GRU / LSTM,
而是:
多头自注意力
残差连接
LayerNorm
前馈网络
共同组成的 Transformer 编码块。
所以 BERT 强,并不只是因为它有自注意力,
还因为它是:
多层深度 Transformer 编码器
9. BERT 编码器前向传播怎么写
常见写法如下:
def forward(self, tokens, segments, valid_lens): X = self.token_embedding(tokens) + self.segment_embedding(segments) X = X + self.pos_embedding.data[:, :X.shape[1], :] for blk in self.blks: X = blk(X, valid_lens) return X这段代码非常关键,因为它把 BERT 编码器的数据流真正串起来了。
10. 为什么一开始要把三种 embedding 相加
这一句的核心是:
X = token_embedding + segment_embedding + position_embedding这意味着:
每个位置的最终输入表示,不是单一来源,而是三种信息的叠加。
具体来说:
token embedding
告诉模型:这个词本身是谁。
segment embedding
告诉模型:它属于句子 A 还是句子 B。
position embedding
告诉模型:它在整个序列中的哪个位置。
所以 BERT 输入层的目标,就是把:
词信息 + 句段信息 + 位置信息
统一融合成一个向量表示。
11.for blk in self.blks: X = blk(X, valid_lens)在干什么
这表示把输入依次送过多层 Transformer Encoder Block。
每一层都会做:
自注意力
前馈网络
残差连接
LayerNorm
经过多层之后,X中每个位置的表示都会变成:
深层上下文化表示
也就是说,到了最后输出时:
[CLS]位置表示的是整段综合信息普通 token 位置表示的是结合上下文后的 token 表示
这正是 BERT 的核心输出。
12. BERT 编码器输出的X到底是什么
最终返回的:
X形状通常是:
(batch_size, num_steps, num_hiddens)它表示:
整个输入序列中,每个位置的上下文化表示
注意,这已经不是原始 embedding 了,
而是经过多层 Transformer 编码之后的结果。
这个输出后面会被分别送给:
MLM 头
NSP 头
或者下游微调任务头
所以 BERT 的主体输出,本质上就是:
高质量上下文化 token 表示序列
13. MLM 预测头为什么单独写一个模块
因为 MLM 的任务不是预测所有位置,
而是只预测那些被选中的 mask 位置。
所以通常会专门写一个MaskLM类,例如:
class MaskLM(nn.Module): def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs): super(MaskLM, self).__init__(**kwargs) self.mlp = nn.Sequential( nn.Linear(num_inputs, num_hiddens), nn.ReLU(), nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, vocab_size) )这说明 MLM 不是直接拿编码器输出硬做分类,
而是会再经过一个小型预测头。
14. 为什么 MLM 头里还要有 MLP
因为 BERT 主体输出的是上下文化表示,
而 MLM 任务需要的是:
对词表中所有 token 的预测分布
中间加一个小 MLP,有几个好处:
第一,增加表达能力
让 MLM 预测头更灵活,不只是一个线性映射。
第二,更贴近预训练目标
让模型在输出到词表前再做一次非线性变换。
第三,和原始 BERT 结构一致
原版 BERT 的 MLM head 本来也不是最简单单层线性。
所以这里写成一个小的投影头是合理的。
15. MLM 前向传播为什么只取被 mask 的位置
常见写法大致如下:
def forward(self, X, pred_positions): num_pred_positions = pred_positions.shape[1] pred_positions = pred_positions.reshape(-1) batch_size = X.shape[0] batch_idx = torch.arange(0, batch_size) batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions) masked_X = X[batch_idx, pred_positions] masked_X = masked_X.reshape((batch_size, num_pred_positions, -1)) mlm_Y_hat = self.mlp(masked_X) return mlm_Y_hat这里最关键的是:
只把被 mask 的那些位置挑出来做预测
因为 MLM 的损失只计算这些位置。
没被 mask 的位置不参与 MLM 分类目标。
所以这段代码本质上是在做:
从编码器输出里,按位置索引抽取 mask 位置表示
再送入 MLM 预测头
16. 为什么pred_positions很重要
pred_positions记录的是:
本条样本里哪些位置被选作 MLM 预测目标
例如一句话长度 10,
可能只 mask 了第 2、5、8 个位置。
那么:
pred_positions = [2, 5, 8]这时 MLM 头就只处理这 3 个位置对应的表示。
所以 BERT 预训练不只是把[MASK]塞进去这么简单,
还必须显式记录:
哪些位置需要参与 MLM loss
17. NSP 头为什么更简单
和 MLM 不同,NSP 是一个句级二分类任务:
判断句子 B 是否是句子 A 的真实后续
所以它只需要拿整段输入的一个综合表示做分类即可。
在 BERT 里,最经典做法就是用:
[CLS]位置对应的输出表示。
因此 NSP 头通常很简单,例如:
class NextSentencePred(nn.Module): def __init__(self, num_inputs, **kwargs): super(NextSentencePred, self).__init__(**kwargs) self.output = nn.Linear(num_inputs, 2) def forward(self, X): return self.output(X)这里输出维度为 2,表示:
是下一句
不是下一句
18. 为什么 NSP 通常只看[CLS]表示
因为[CLS]的设计目标本来就是:
汇总整段输入的全局信息
经过多层 Transformer 编码后,[CLS]位置的表示通常已经融合了整句甚至句对的整体语义。
所以拿它做:
文本分类
句对判断
NSP
都很自然。
这也是为什么 BERT 对句级任务特别方便。
19. 最终 BERT 模型怎么组装
李沐这里常见的总模型类,大致会写成这样:
class BERTModel(nn.Module): def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=1000, key_size=768, query_size=768, value_size=768, hid_in_features=768, mlm_in_features=768, nsp_in_features=768): super(BERTModel, self).__init__() self.encoder = BERTEncoder( vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=max_len, key_size=key_size, query_size=query_size, value_size=value_size) self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh()) self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features) self.nsp = NextSentencePred(nsp_in_features)这就把:
BERT 编码器
MLM 头
NSP 头
全都组装到一起了。
20. 为什么这里还有一个self.hidden
这一句:
self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())通常是为了对[CLS]表示先做一次额外变换,
再送入 NSP 分类头。
也就是说:
编码器输出
[CLS]表示经过一个小投影层
再做 NSP 分类
这和原始 BERT 的设计是对应的。
它让句级分类头不只是一个裸线性层,而是有一个中间变换。
21. BERT 总模型前向传播怎么写
常见写法如下:
def forward(self, tokens, segments, valid_lens=None, pred_positions=None): encoded_X = self.encoder(tokens, segments, valid_lens) if pred_positions is not None: mlm_Y_hat = self.mlm(encoded_X, pred_positions) else: mlm_Y_hat = None nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :])) return encoded_X, mlm_Y_hat, nsp_Y_hat这段代码把 BERT 的整体输出逻辑一下串起来了。
22. 为什么返回三个结果
这一点很关键:
return encoded_X, mlm_Y_hat, nsp_Y_hat说明 BERT 总模型不是只输出一个东西。
encoded_X
表示整段输入所有位置的上下文化表示。
后续微调任务常常直接用它。
mlm_Y_hat
表示 MLM 任务在 mask 位置上的预测结果。
预训练时要算 MLM loss。
nsp_Y_hat
表示 NSP 二分类结果。
预训练时要算 NSP loss。
所以 BERT 模型主体并不是某个单一任务模型,
而是:
一个共享编码器 + 多任务预训练头
23. 为什么encoded_X[:, 0, :]就是[CLS]表示
因为[CLS]通常被放在输入序列最前面。
所以编码器输出张量:
encoded_X.shape = (batch_size, num_steps, num_hiddens)中:
encoded_X[:, 0, :]就表示每个样本第 0 个位置,也就是[CLS]的输出表示。
这正好能作为整句综合表示,用于 NSP 或其他句级任务。
24. 这一节代码最该掌握什么
如果从学习重点来看,最关键的是下面几件事。
24.1 BERT 输入表示由哪三部分组成
token embedding
segment embedding
position embedding
24.2 BERT 主体为什么是多层 Transformer Encoder
这是它上下文化表示能力的来源。
24.3 MLM 头和 NSP 头分别做什么
MLM:token 级预测
NSP:句级二分类
24.4 为什么 MLM 只取被 mask 的位置
因为预训练目标只在这些位置上计算。
24.5 为什么[CLS]位置输出可用于句级任务
因为它是整段输入的综合表示。
25. 这一节和后面几节怎么衔接
这一节其实只是把BERT 模型结构搭起来。
而你给的目录后面还会继续讲:
BERT 预训练数据代码
BERT 预训练代码
BERT 微调
自然语言推理数据集
BERT 微调代码
所以这节可以理解成:
先搭模型本体
后面几节才会逐步补齐:
数据怎么造
损失怎么算
训练怎么跑
下游任务怎么接
这个顺序非常合理。
26. 本节总结
这一节我们学习了 BERT 的代码结构,核心内容可以总结为以下几点。
26.1 BERT 编码器由三部分输入表示和多层 Transformer Encoder 组成
这是模型主体。
26.2 输入表示包括 token、segment 和 position 三类 embedding
三者相加形成最终输入。
26.3 MLM 头负责预测被 mask 的 token
因此只取被 mask 的位置做分类。
26.4 NSP 头负责句级二分类
通常使用[CLS]位置的输出表示。
26.5 整个 BERT 模型本质上是“共享编码器 + 多任务预训练头”
这是它预训练范式的核心。
27. 学习感悟
这一节特别重要,因为它让你第一次真正看到:
BERT 并不是一个“神秘黑箱”,
它其实就是:
Transformer 编码器 + 精心设计的输入表示 + 预训练任务头。
很多时候,大家一提 BERT 就觉得它很大、很复杂。
但如果把结构拆开,其实非常清楚:
先把输入表示准备好
用深层自注意力编码
再接两个预训练任务头
真正难的地方,不是它概念多玄,而是它把这些部分组织得特别好。
