当前位置: 首页 > news >正文

动手学深度学习——BERT代码

1. 前言

上一篇我们已经从整体上理解了BERT

  • 它是基于 Transformer Encoder 的双向预训练语言模型

  • 它采用“预训练 + 微调”的范式

  • 它的核心预训练任务包括:

    • MLM(Masked Language Modeling)

    • NSP(Next Sentence Prediction)

  • 它特别适合各种自然语言理解任务

这一篇就继续按李沐的思路,把 BERT 真正落到代码层面。

这一节最重要的,不是一下子把整个预训练过程全写完,
而是先看清楚 BERT 这个模型本身到底由哪些部分组成:

  • 输入表示怎么构造

  • BERT 编码器层怎么堆叠

  • MLM 预测头怎么接

  • NSP 分类头怎么接

  • 前向传播输出到底有哪些东西

如果一句话概括这一节代码的核心,那就是:

BERT = 输入嵌入层 + 多层 Transformer 编码器 + 预训练任务头


2. BERT 代码要解决什么问题

如果从实现角度看,BERT 这节代码主要解决三件事:

第一,构造输入表示

BERT 输入不是单纯 token embedding,
而是:

  • token embedding

  • segment embedding

  • position embedding

三者相加。

第二,搭建深层 Transformer Encoder

让输入序列经过多层自注意力和前馈网络,得到上下文化表示。

第三,接上预训练任务输出头

包括:

  • MLM 头:预测被 mask 的 token

  • NSP 头:判断两句是否连续

所以这一节并不是直接训练,
而是在把:

BERT 的骨架和输出接口

先搭完整。


3. BERT 的输入为什么比普通模型复杂

前面我们学 RNN、GRU、LSTM、Seq2Seq 时,
输入往往主要是:

  • token 索引

  • 或者 token embedding

但 BERT 不一样。
因为它不仅要处理单句,还常常要处理句对任务,
并且 Transformer 本身没有顺序递推结构,所以还必须显式加入位置信息。

因此,BERT 输入通常由三部分组成:

3.1 Token Embedding

表示词本身是谁。

3.2 Segment Embedding

表示这个 token 属于句子 A 还是句子 B。

3.3 Position Embedding

表示这个 token 在序列中的位置。

所以 BERT 输入层本质上比前面的模型更“结构化”。


4. BERT 输入嵌入层通常怎么写

李沐这里常见的 BERT 编码器初始化,大致会写成这样:

class BERTEncoder(nn.Module): def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=1000, key_size=768, query_size=768, value_size=768, **kwargs): super(BERTEncoder, self).__init__(**kwargs) self.token_embedding = nn.Embedding(vocab_size, num_hiddens) self.segment_embedding = nn.Embedding(2, num_hiddens) self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module( f"{i}", d2l.EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, True) ) self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

这段代码基本就把 BERT 编码器的主体结构全展示出来了。


5. 为什么token_embedding很好理解

这一句:

self.token_embedding = nn.Embedding(vocab_size, num_hiddens)

和前面很多 NLP 模型一样,
作用就是把 token 索引变成稠密向量。

例如:

  • deep

  • learning

  • [CLS]

  • [SEP]

这些 token 本身是整数编号,
经过 embedding 以后,变成:

num_hiddens

维的向量表示。

这是所有输入表示的基础。


6. 为什么segment_embedding只需要 2

这一句:

self.segment_embedding = nn.Embedding(2, num_hiddens)

中的2表示:

只区分两种 segment

通常是:

  • 句子 A

  • 句子 B

例如在句对输入中:

[CLS] 句子A [SEP] 句子B [SEP]

那么:

  • 句子 A 对应 segment id = 0

  • 句子 B 对应 segment id = 1

所以 segment embedding 只需要 2 种类型就够了。

它的作用是告诉模型:

当前 token 属于哪一段句子

这对句对任务很重要。


7. 为什么 BERT 的位置编码这里是可学习参数

这一句:

self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

表示:

BERT 这里使用的是可学习的位置编码

也就是说,位置向量不是固定公式生成的,
而是直接作为模型参数参与训练。

这和最原始 Transformer 论文里常见的正弦位置编码不同。
BERT 采用可学习位置嵌入,也是很经典的做法。

它的含义很简单:

  • 序列第 1 个位置有一个向量

  • 第 2 个位置有另一个向量

  • 这些位置向量会在训练中自动优化


8. 为什么 BERT 需要多层EncoderBlock

这一段:

self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module(... d2l.EncoderBlock(...))

说明 BERT 编码器并不是一层 Transformer,
而是:

多层 Transformer Encoder Block 堆叠

这和前面讲“深层循环神经网络”很像,
只不过这里堆叠的不是 RNN / GRU / LSTM,
而是:

  • 多头自注意力

  • 残差连接

  • LayerNorm

  • 前馈网络

共同组成的 Transformer 编码块。

所以 BERT 强,并不只是因为它有自注意力,
还因为它是:

多层深度 Transformer 编码器


9. BERT 编码器前向传播怎么写

常见写法如下:

def forward(self, tokens, segments, valid_lens): X = self.token_embedding(tokens) + self.segment_embedding(segments) X = X + self.pos_embedding.data[:, :X.shape[1], :] for blk in self.blks: X = blk(X, valid_lens) return X

这段代码非常关键,因为它把 BERT 编码器的数据流真正串起来了。


10. 为什么一开始要把三种 embedding 相加

这一句的核心是:

X = token_embedding + segment_embedding + position_embedding

这意味着:

每个位置的最终输入表示,不是单一来源,而是三种信息的叠加。

具体来说:

token embedding

告诉模型:这个词本身是谁。

segment embedding

告诉模型:它属于句子 A 还是句子 B。

position embedding

告诉模型:它在整个序列中的哪个位置。

所以 BERT 输入层的目标,就是把:

词信息 + 句段信息 + 位置信息

统一融合成一个向量表示。


11.for blk in self.blks: X = blk(X, valid_lens)在干什么

这表示把输入依次送过多层 Transformer Encoder Block。

每一层都会做:

  • 自注意力

  • 前馈网络

  • 残差连接

  • LayerNorm

经过多层之后,X中每个位置的表示都会变成:

深层上下文化表示

也就是说,到了最后输出时:

  • [CLS]位置表示的是整段综合信息

  • 普通 token 位置表示的是结合上下文后的 token 表示

这正是 BERT 的核心输出。


12. BERT 编码器输出的X到底是什么

最终返回的:

X

形状通常是:

(batch_size, num_steps, num_hiddens)

它表示:

整个输入序列中,每个位置的上下文化表示

注意,这已经不是原始 embedding 了,
而是经过多层 Transformer 编码之后的结果。

这个输出后面会被分别送给:

  • MLM 头

  • NSP 头

  • 或者下游微调任务头

所以 BERT 的主体输出,本质上就是:

高质量上下文化 token 表示序列


13. MLM 预测头为什么单独写一个模块

因为 MLM 的任务不是预测所有位置,
而是只预测那些被选中的 mask 位置。

所以通常会专门写一个MaskLM类,例如:

class MaskLM(nn.Module): def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs): super(MaskLM, self).__init__(**kwargs) self.mlp = nn.Sequential( nn.Linear(num_inputs, num_hiddens), nn.ReLU(), nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, vocab_size) )

这说明 MLM 不是直接拿编码器输出硬做分类,
而是会再经过一个小型预测头。


14. 为什么 MLM 头里还要有 MLP

因为 BERT 主体输出的是上下文化表示,
而 MLM 任务需要的是:

对词表中所有 token 的预测分布

中间加一个小 MLP,有几个好处:

第一,增加表达能力

让 MLM 预测头更灵活,不只是一个线性映射。

第二,更贴近预训练目标

让模型在输出到词表前再做一次非线性变换。

第三,和原始 BERT 结构一致

原版 BERT 的 MLM head 本来也不是最简单单层线性。

所以这里写成一个小的投影头是合理的。


15. MLM 前向传播为什么只取被 mask 的位置

常见写法大致如下:

def forward(self, X, pred_positions): num_pred_positions = pred_positions.shape[1] pred_positions = pred_positions.reshape(-1) batch_size = X.shape[0] batch_idx = torch.arange(0, batch_size) batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions) masked_X = X[batch_idx, pred_positions] masked_X = masked_X.reshape((batch_size, num_pred_positions, -1)) mlm_Y_hat = self.mlp(masked_X) return mlm_Y_hat

这里最关键的是:

只把被 mask 的那些位置挑出来做预测

因为 MLM 的损失只计算这些位置。
没被 mask 的位置不参与 MLM 分类目标。

所以这段代码本质上是在做:

  • 从编码器输出里,按位置索引抽取 mask 位置表示

  • 再送入 MLM 预测头


16. 为什么pred_positions很重要

pred_positions记录的是:

本条样本里哪些位置被选作 MLM 预测目标

例如一句话长度 10,
可能只 mask 了第 2、5、8 个位置。
那么:

pred_positions = [2, 5, 8]

这时 MLM 头就只处理这 3 个位置对应的表示。

所以 BERT 预训练不只是把[MASK]塞进去这么简单,
还必须显式记录:

哪些位置需要参与 MLM loss


17. NSP 头为什么更简单

和 MLM 不同,NSP 是一个句级二分类任务:

判断句子 B 是否是句子 A 的真实后续

所以它只需要拿整段输入的一个综合表示做分类即可。

在 BERT 里,最经典做法就是用:

[CLS]

位置对应的输出表示。

因此 NSP 头通常很简单,例如:

class NextSentencePred(nn.Module): def __init__(self, num_inputs, **kwargs): super(NextSentencePred, self).__init__(**kwargs) self.output = nn.Linear(num_inputs, 2) def forward(self, X): return self.output(X)

这里输出维度为 2,表示:

  • 是下一句

  • 不是下一句


18. 为什么 NSP 通常只看[CLS]表示

因为[CLS]的设计目标本来就是:

汇总整段输入的全局信息

经过多层 Transformer 编码后,
[CLS]位置的表示通常已经融合了整句甚至句对的整体语义。

所以拿它做:

  • 文本分类

  • 句对判断

  • NSP

都很自然。

这也是为什么 BERT 对句级任务特别方便。


19. 最终 BERT 模型怎么组装

李沐这里常见的总模型类,大致会写成这样:

class BERTModel(nn.Module): def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=1000, key_size=768, query_size=768, value_size=768, hid_in_features=768, mlm_in_features=768, nsp_in_features=768): super(BERTModel, self).__init__() self.encoder = BERTEncoder( vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=max_len, key_size=key_size, query_size=query_size, value_size=value_size) self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh()) self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features) self.nsp = NextSentencePred(nsp_in_features)

这就把:

  • BERT 编码器

  • MLM 头

  • NSP 头

全都组装到一起了。


20. 为什么这里还有一个self.hidden

这一句:

self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh())

通常是为了对[CLS]表示先做一次额外变换,
再送入 NSP 分类头。

也就是说:

  • 编码器输出[CLS]表示

  • 经过一个小投影层

  • 再做 NSP 分类

这和原始 BERT 的设计是对应的。
它让句级分类头不只是一个裸线性层,而是有一个中间变换。


21. BERT 总模型前向传播怎么写

常见写法如下:

def forward(self, tokens, segments, valid_lens=None, pred_positions=None): encoded_X = self.encoder(tokens, segments, valid_lens) if pred_positions is not None: mlm_Y_hat = self.mlm(encoded_X, pred_positions) else: mlm_Y_hat = None nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :])) return encoded_X, mlm_Y_hat, nsp_Y_hat

这段代码把 BERT 的整体输出逻辑一下串起来了。


22. 为什么返回三个结果

这一点很关键:

return encoded_X, mlm_Y_hat, nsp_Y_hat

说明 BERT 总模型不是只输出一个东西。

encoded_X

表示整段输入所有位置的上下文化表示。
后续微调任务常常直接用它。

mlm_Y_hat

表示 MLM 任务在 mask 位置上的预测结果。
预训练时要算 MLM loss。

nsp_Y_hat

表示 NSP 二分类结果。
预训练时要算 NSP loss。

所以 BERT 模型主体并不是某个单一任务模型,
而是:

一个共享编码器 + 多任务预训练头


23. 为什么encoded_X[:, 0, :]就是[CLS]表示

因为[CLS]通常被放在输入序列最前面。

所以编码器输出张量:

encoded_X.shape = (batch_size, num_steps, num_hiddens)

中:

encoded_X[:, 0, :]

就表示每个样本第 0 个位置,也就是[CLS]的输出表示。

这正好能作为整句综合表示,用于 NSP 或其他句级任务。


24. 这一节代码最该掌握什么

如果从学习重点来看,最关键的是下面几件事。

24.1 BERT 输入表示由哪三部分组成

  • token embedding

  • segment embedding

  • position embedding

24.2 BERT 主体为什么是多层 Transformer Encoder

这是它上下文化表示能力的来源。

24.3 MLM 头和 NSP 头分别做什么

  • MLM:token 级预测

  • NSP:句级二分类

24.4 为什么 MLM 只取被 mask 的位置

因为预训练目标只在这些位置上计算。

24.5 为什么[CLS]位置输出可用于句级任务

因为它是整段输入的综合表示。


25. 这一节和后面几节怎么衔接

这一节其实只是把BERT 模型结构搭起来。
而你给的目录后面还会继续讲:

  • BERT 预训练数据代码

  • BERT 预训练代码

  • BERT 微调

  • 自然语言推理数据集

  • BERT 微调代码

所以这节可以理解成:

先搭模型本体

后面几节才会逐步补齐:

  • 数据怎么造

  • 损失怎么算

  • 训练怎么跑

  • 下游任务怎么接

这个顺序非常合理。


26. 本节总结

这一节我们学习了 BERT 的代码结构,核心内容可以总结为以下几点。

26.1 BERT 编码器由三部分输入表示和多层 Transformer Encoder 组成

这是模型主体。

26.2 输入表示包括 token、segment 和 position 三类 embedding

三者相加形成最终输入。

26.3 MLM 头负责预测被 mask 的 token

因此只取被 mask 的位置做分类。

26.4 NSP 头负责句级二分类

通常使用[CLS]位置的输出表示。

26.5 整个 BERT 模型本质上是“共享编码器 + 多任务预训练头”

这是它预训练范式的核心。


27. 学习感悟

这一节特别重要,因为它让你第一次真正看到:

BERT 并不是一个“神秘黑箱”,
它其实就是:
Transformer 编码器 + 精心设计的输入表示 + 预训练任务头

很多时候,大家一提 BERT 就觉得它很大、很复杂。
但如果把结构拆开,其实非常清楚:

  • 先把输入表示准备好

  • 用深层自注意力编码

  • 再接两个预训练任务头

真正难的地方,不是它概念多玄,而是它把这些部分组织得特别好。

http://www.jsqmd.com/news/645157/

相关文章:

  • B站视频下载神器BilibiliDown:3步搞定离线观看与批量收藏的完整指南
  • 2026年客服软件哪个易用?实用在线客服系统体验测评指南 - 品牌2026
  • 基础只是:发动机、变速器、地盘、电池、电机、电控、智能座仓、辅助驾驶 / 当代汽车八大件
  • CosyVoice3保姆级使用指南:3秒音频克隆人声,自然语言控制情感
  • 保姆级教程:用AdGuard DNS代理实现全设备广告过滤(含Win/Mac/安卓/iOS配置)
  • 5步掌握WeNet:从零部署到生产级语音识别系统
  • 热力管道保温施工团队哪家实力强?施工能力大比拼 - 品牌推荐大师
  • pkNX宝可梦编辑器完全指南:从零开始定制你的Switch宝可梦游戏
  • 2025届学术党必备的五大AI学术助手实际效果
  • Qt Release版本打包成单文件exe的完整指南(含Enigma Virtual Box配置)
  • PyTorch 2.6 快速上手:基于镜像的深度学习项目实战教程
  • 如何快速掌握开源项目管理:5个核心功能打造高效团队协作空间
  • 【避坑指南】UniApp中getLocation坐标转换的精准定位实践
  • 【行业深度对谈】穿透“文凭焦虑”:翼程教育17年深耕江苏,合规办学助力长三角人才学历突围 - 商业科技观察
  • 2026企业级国产OpenClaw安全合规工具怎么选?推荐开源智能体 - 品牌2025
  • Axure RP中文语言包完全指南:5分钟实现专业界面本地化
  • CCS更换芯片型号必看:避免FLASH memory冲突的3种实用解决方案
  • 苍穹外卖debug篇
  • 从SDK到Vitis:FPGA工程迁移的完整指南与实战技巧
  • 智能体学习20——人类参与环节(Human-in-the-Loop)
  • NVIDIA Profile Inspector深度指南:解锁显卡隐藏性能的专业工具
  • Paimon与Flink CDC实战:从MySQL到实时数据湖的构建
  • 数据结构作业—用队列求解迷宫问题
  • Java异常处理实战:从EduCoder平台到真实项目的避坑指南
  • 突破百度网盘限速封锁:开源解析工具终极使用秘籍
  • WaveTools终极指南:三招提升《鸣潮》游戏体验的完整解决方案
  • 手把手教你用Simulink搭建级联H桥储能变流器仿真模型(附SOC均衡分析)
  • 闲置微信立减金别浪费!安全回收攻略,避开陷阱快速落袋 - 可可收
  • 3步快速解密网易云音乐NCM文件:免费工具完整指南
  • STM32调试接口锁死(No ST-LINK detected)的深度排查与解锁指南