从零构建大语言模型:深入理解Transformer架构与PyTorch实践
1. 从零开始理解大语言模型:为什么我们需要亲手搭建?
如果你和我一样,对ChatGPT、Claude这些大语言模型(LLM)的涌现感到既兴奋又困惑,那么“从零开始搭建”这个想法可能不止一次在你脑海中闪过。兴奋的是,这些模型展现出的理解和生成能力,正在重塑我们与技术交互的方式;困惑的是,它们内部仿佛一个黑箱,动辄千亿的参数、复杂的Transformer架构,让人望而却步。市面上充斥着各种调用API的教程,告诉你如何“使用”LLM,但关于其核心原理和构建过程的深度内容却相对稀少。这就像只学会了开车,却对发动机如何工作一无所知。
这正是Sebastian Raschka的“LLM Workshop 2024”以及其著作《Build a Large Language Model From Scratch》的价值所在。这个项目不是一个简单的“Hello World”式演示,而是一次深入骨髓的解剖实验。它基于PyTorch,引导你从最基础的数据处理开始,一步步编码实现一个类GPT的模型架构,并完成预训练和微调的全流程。其核心目的不是让你造出一个媲美GPT-4的模型——那需要天文数字的算力和数据——而是让你彻底理解构成现代LLM的每一个“乐高积木”:Tokenizer如何工作、Attention机制如何实现、训练循环如何组织。当你亲手用代码将这些模块组装起来,并看到它开始生成哪怕是最简单的文本时,你对LLM的理解将发生质变。
无论你是机器学习工程师、数据科学家,还是对AI底层技术充满好奇的开发者,这个旅程都极具价值。它能帮你摆脱对大型科技公司“黑箱模型”的依赖感,让你在调试模型、进行领域适配或尝试架构创新时,拥有坚实的理论基础和实操直觉。接下来,我将结合工作坊的核心模块,为你拆解其中的关键技术与实操要点,并补充大量原课程中点到即止的细节和我在复现过程中踩过的坑。
2. 项目核心思路与架构总览
2.1 逆向工程式的学习路径
这个工作坊采用了一种非常有效的“逆向工程”式教学法。通常,我们学习深度学习框架是从高层API开始,逐渐深入。但理解LLM,恰恰需要从底层向上构建。项目的设计路线非常清晰:
- 数据层面:从原始文本开始,实现分词器(Tokenizer)和数据加载器(DataLoader),理解模型“吃进去”的是什么。
- 模型层面:逐一实现Transformer的核心组件,如嵌入层、多头注意力机制、前馈网络、层归一化等,最后组装成完整的GPT类模型。
- 训练层面:编写预训练(Pretraining)循环,使用一个小的、公开的文本数据集(如维基百科或小说片段)来训练模型,目标是让模型学会基本的语言建模(预测下一个词)。
- 工程实践:引入LitGPT这个开源库,学习如何加载真实的、大规模预训练好的模型权重(如Llama 2、Mistral),并在此基础上进行指令微调(Instruction Finetuning)。
这个路径的巧妙之处在于,它先用一个“玩具级”的完整流程建立你的全局认知,然后再带你接触工业级工具和模型,让你明白之前手写的代码与成熟库中的实现有何异同,以及如何衔接。
2.2 为什么选择PyTorch和LitGPT?
工作坊选择PyTorch作为基础框架是顺理成章的。PyTorch的动态计算图和直观的API设计,使其非常适合教学、研究和原型开发。你可以像搭积木一样打印、调试每一层的输出,这对于理解模型内部状态流动至关重要。
而LitGPT的选择则体现了从教学到实践的平滑过渡。LitGPT是Lightning AI团队维护的一个开源库,它的核心目标是清晰和可用。与Transformers等大型库相比,LitGPT的代码库更精简,没有过多的抽象层,很多训练、加载模型的脚本可以直接阅读和修改。它支持众多流行的开源模型(如Llama、Phi、Gemma、Mistral),并提供了简洁统一的接口来加载权重和进行微调。通过它,你可以将之前学到的原理,快速应用到真实的预训练模型上,完成一个具体的下游任务(如指令跟随),形成学习闭环。
注意:对于完全零基础的读者,建议先具备基本的Python和PyTorch知识。如果你对张量操作、自动求导、简单的神经网络有了解,那么跟上这个工作坊会顺畅很多。如果缺乏这些基础,可能会在理解模型前向传播和损失计算时遇到障碍。
3. 基石:从文本到张量——数据管道深度解析
任何机器学习项目的成功,一半以上取决于数据。对于LLM,数据管道的核心是将人类可读的文本,转化为模型可处理的数值张量。这一步看似简单,却隐藏着许多设计抉择。
3.1 分词器(Tokenizer)的实现与选择
分词是将文本切分成模型能理解的基本单元(Token)的过程。工作坊中实现了一个基于字节对编码(BPE)的简化分词器。BPE是GPT系列模型使用的算法,其核心思想是从基础字符(如字母)开始,通过迭代合并最高频的相邻符号对,逐步构建出一个词表。
手动实现BPE的核心步骤:
- 初始化:将文本拆分为UTF-8字节序列,每个字节作为一个基础token。
- 统计频率:计算所有相邻token对在语料中出现的频率。
- 合并:找到频率最高的token对,将其合并为一个新的token,并加入词表。
- 迭代:重复步骤2和3,直到词表大小达到预设值(例如,5000)。
# 一个极度简化的BPE合并过程示意 def get_stats(vocab): pairs = collections.defaultdict(int) for word, freq in vocab.items(): symbols = word.split() for i in range(len(symbols)-1): pairs[symbols[i], symbols[i+1]] += freq return pairs def merge_vocab(pair, v_in): v_out = {} bigram = ' '.join(pair) replacement = ''.join(pair) for word in v_in: w_out = word.replace(bigram, replacement) v_out[w_out] = v_in[word] return v_out实操心得:
- 词表大小:这是一个关键超参数。太小(如1k),模型表达能力弱,一个词可能被切成很多片;太大(如100k),则模型参数增多,训练更慢,且可能包含大量低频词。对于教学项目,5k-10k是一个合理的范围。
- 未知词处理:BPE的一个优点是理论上可以编码任何单词,因为它是基于字节的。但实践中,我们仍会设置一个
<unk>token来处理极端情况。 - 特殊Token:必须添加
<bos>(序列开始)、<eos>(序列结束)、<pad>(填充)等特殊token。它们在数据对齐和训练中至关重要。
为什么不用现成的Tokenizer?工作坊要求手写,是为了让你理解“hello world”是如何变成[123, 456]这两个ID的。在实际项目中,我们当然直接使用Hugging Face的tokenizers库或对应模型的官方分词器。
3.2 数据加载器(DataLoader)与上下文窗口
得到Token ID序列后,我们需要将其组织成模型训练所需的批次数据。对于语言模型,训练样本是固定长度的文本片段。
关键实现:
- 分块:将整个语料库的Token ID序列,切割成连续的长度为
block_size(即上下文窗口,如256)的块。 - 构建输入-目标对:对于每个文本块,输入(
x)是前block_size个token,目标(y)是后移一位的block_size个token。模型的任务是根据x预测y。 - 批次生成:随机抽取一批这样的
(x, y)对,组成一个训练批次。
class TextDataset(Dataset): def __init__(self, text, tokenizer, block_size): self.data = tokenizer.encode(text) # 得到token id列表 self.block_size = block_size def __len__(self): return len(self.data) - self.block_size def __getitem__(self, idx): # 取一段连续的token作为输入 x = self.data[idx: idx + self.block_size] # 目标是输入向右移动一位 y = self.data[idx + 1: idx + 1 + self.block_size] return torch.tensor(x), torch.tensor(y)注意事项:
- 上下文窗口(block_size):它决定了模型一次能“看到”多长的历史信息。较小的窗口(128)训练快,但模型记性差;较大的窗口(2048)能处理长文本,但显存消耗呈平方级增长(由于注意力机制)。教学项目通常设为256或512。
- 数据随机化:在
__getitem__中随机选择起始索引idx,可以确保每个epoch的数据顺序都不同,有利于模型泛化。 - 填充(Padding):如果使用批次训练,且序列长度不一致,需要对短序列进行填充。但在语言建模中,我们通常将数据预处理成等长的块,从而避免填充,简化计算。
4. 核心架构:手搓一个微型GPT
这是整个工作坊最硬核、也最令人兴奋的部分。我们将用PyTorch模块组装一个完整的Decoder-Only Transformer模型,也就是GPT的结构。
4.1 核心组件拆解
一个标准的GPT层主要由以下模块构成:
嵌入层(Embedding):将输入的token ID(形状
[batch, seq_len])映射为稠密向量(形状[batch, seq_len, hidden_size])。这里包含两个嵌入层:token_embedding(词嵌入)和position_embedding(位置嵌入)。位置嵌入用于让模型感知token的顺序信息,通常使用可学习的位置编码或正弦余弦固定编码。层归一化(LayerNorm):在Transformer中,层归一化被广泛应用在子层(如注意力、前馈网络)之前或之后(Pre-Norm或Post-Norm)。GPT通常采用Pre-Norm,即先对输入进行归一化,再送入子层。这有助于稳定深层网络的训练。
多头自注意力(Multi-Head Self-Attention):这是Transformer的灵魂。其核心公式是
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V。- 实现步骤: a. 将嵌入向量通过线性变换,拆分成多个头(head),得到Q, K, V。 b. 计算每个头的注意力分数(Q和K的点积,并缩放)。 c. 应用因果掩码(Causal Mask),确保当前位置只能关注到过去的位置,这是语言模型生成未来文本的关键。 d. 对注意力分数做softmax,得到权重。 e. 用权重加权求和V。 f. 将多个头的输出拼接起来,通过一个线性投影层输出。
class CausalSelfAttention(nn.Module): def __init__(self, hidden_size, num_heads): super().__init__() self.num_heads = num_heads self.head_dim = hidden_size // num_heads # 定义Q, K, V的投影层 self.qkv = nn.Linear(hidden_size, 3 * hidden_size) self.proj = nn.Linear(hidden_size, hidden_size) # 因果掩码:下三角矩阵,True的位置将被屏蔽(设为负无穷) self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)) def forward(self, x): B, T, C = x.shape # batch, seq_len, hidden_size qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] att = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5) # 缩放点积 att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) # 应用因果掩码 att = F.softmax(att, dim=-1) out = att @ v out = out.transpose(1, 2).contiguous().view(B, T, C) out = self.proj(out) return out前馈网络(Feed-Forward Network):一个简单的两层MLP,通常中间有一个扩展因子(如4倍)。公式为
FFN(x) = W2 * GELU(W1 * x + b1) + b2。它为每个位置的表示增加了非线性变换能力。残差连接(Residual Connection):每个子层(注意力、前馈)的输出都会与输入相加,即
output = x + sublayer(x)。这是训练非常深网络的关键技术,能有效缓解梯度消失。
4.2 组装成GPT模型
将多个上述的“Transformer Block”堆叠起来,前面加上嵌入层,后面加上一个用于输出词表概率的线性层(通常称为LM Head),就构成了一个完整的GPT模型。
class MiniGPT(nn.Module): def __init__(self, vocab_size, block_size, hidden_size, num_layers, num_heads): super().__init__() self.token_embed = nn.Embedding(vocab_size, hidden_size) self.pos_embed = nn.Embedding(block_size, hidden_size) self.blocks = nn.Sequential(*[TransformerBlock(hidden_size, num_heads) for _ in range(num_layers)]) self.ln_f = nn.LayerNorm(hidden_size) # 最终层归一化 self.lm_head = nn.Linear(hidden_size, vocab_size) def forward(self, idx): B, T = idx.shape tok_emb = self.token_embed(idx) pos = torch.arange(0, T, device=idx.device) pos_emb = self.pos_embed(pos) x = tok_emb + pos_emb x = self.blocks(x) x = self.ln_f(x) logits = self.lm_head(x) return logits参数选择与经验:
- hidden_size(模型维度):决定了模型表示能力的宽度。教学模型可以设为128或256。
- num_layers(层数):决定了模型的深度。教学模型可以设为6或8层。
- num_heads(注意力头数):通常 hidden_size 需要能被 num_heads 整除。头数越多,模型可以并行关注不同方面的信息。对于小模型,4或8个头是常见选择。
- 初始化:Transformer组件的参数初始化非常重要。通常,线性层的权重会用Xavier或Kaiming初始化,而嵌入层会用较小的标准差(如0.02)进行正态分布初始化。工作坊的代码或LitGPT中会包含合适的初始化方法。
5. 预训练实战:让模型学会“说话”
有了模型和数据,接下来就是最耗资源的环节——预训练。目标是通过大量文本,让模型学会预测下一个词,从而获得通用的语言知识。
5.1 训练循环与损失函数
语言模型的预训练是标准的自监督学习。我们使用交叉熵损失(Cross-Entropy Loss)来衡量模型预测的概率分布与真实的下一个词(one-hot编码)之间的差距。
核心训练循环伪代码:
model = MiniGPT(...).to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) scaler = torch.cuda.amp.GradScaler() # 混合精度训练,节省显存 for epoch in range(num_epochs): for batch_idx, (x, y) in enumerate(train_loader): x, y = x.to(device), y.to(device) optimizer.zero_grad() with torch.cuda.amp.autocast(): logits = model(x) # 形状: [batch, seq_len, vocab_size] # 将logits和y reshape成二维,方便计算损失 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1)) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() # 定期记录损失,评估模型关键技巧与参数:
- 学习率:3e-4是Transformer模型常用的初始学习率。可以使用学习率预热(Warmup)策略,在训练初期从小学习率逐步增加到设定值,有助于稳定训练。
- 优化器:AdamW(带权重衰减的Adam)是目前的主流选择。其权重衰减参数(通常设为0.1或0.01)对于防止过拟合很重要。
- 梯度裁剪:当梯度范数超过某个阈值时,将其缩放。这可以防止训练不稳定和梯度爆炸。
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)。 - 混合精度训练:使用
torch.cuda.amp可以显著减少显存占用并加速训练,尤其对于大模型至关重要。
5.2 文本生成(推理)
训练过程中,我们需要定期评估模型是否真的学会了语言。最直观的方式就是让它生成文本。
自回归生成(Autoregressive Generation):
- 给定一个起始提示(prompt),例如
“The weather today is”,将其分词并输入模型。 - 模型输出最后一个位置对所有词表token的预测概率(logits)。
- 通过采样策略(如贪婪采样、核采样、温度采样)从概率分布中选出一个token作为下一个词。
- 将新生成的token拼接到输入序列末尾,作为新的输入,重复步骤2-3,直到生成指定长度或遇到结束符。
def generate(model, prompt, max_new_tokens=50, temperature=1.0): model.eval() tokens = tokenizer.encode(prompt) for _ in range(max_new_tokens): # 只取最后block_size个token作为模型输入(如果超过) idx_cond = tokens if len(tokens) <= block_size else tokens[-block_size:] idx_tensor = torch.tensor(idx_cond).unsqueeze(0).to(device) with torch.no_grad(): logits = model(idx_tensor) logits = logits[:, -1, :] / temperature # 取最后一个位置的logits,并应用温度 probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) tokens.append(next_token.item()) if next_token.item() == eos_token_id: # 遇到结束符则停止 break return tokenizer.decode(tokens)采样策略解析:
- 贪婪采样:总是选择概率最高的token。生成结果确定但容易重复、枯燥。
- 温度采样:通过温度参数
T控制分布的平滑度。T=1使用原始分布;T>1使分布更平缓,增加多样性;T<1使分布更尖锐,增加确定性。 - Top-k / Top-p(核采样):只从概率最高的k个token中采样,或从累积概率达到p的最小token集合中采样。这能避免采样到低质量的生僻词,是实践中常用的方法。
实操心得:在预训练初期,模型生成的文本几乎是乱码。随着损失下降,你会慢慢看到它开始学会拼写单词、组合短语、甚至模仿简单的语法结构。这个过程非常有趣,也是检验模型学习进度的最好方式。由于算力限制,教学项目的预训练数据量和步数有限,模型最终可能只会生成一些简单的、符合统计规律的句子,但这足以证明整个流程是跑通的。
6. 加载预训练权重与使用LitGPT
从头预训练一个有用的LLM对个人开发者来说几乎是不可能的。因此,工作坊的后半部分转向了更实际的路径:加载开源社区预训练好的强大模型,并对其进行微调以适应特定任务。
6.1 权重加载:将知识注入你的架构
工作坊会指导你如何将下载的预训练模型权重(例如,一个Hugging Face格式的Llama 2模型),加载到你之前手写的MiniGPT架构中。这个过程本质上是一个“键名映射”的练习。
核心步骤:
- 下载权重:从Hugging Face或模型官方渠道下载包含
pytorch_model.bin或safetensors文件的权重。 - 检查架构对齐:确保你的模型定义(层数、隐藏大小、头数等)与预训练权重完全匹配。一个参数对不上都会导致加载失败。
- 键名映射:预训练权重的状态字典(state_dict)中的键名(如
transformer.h.0.attn.c_attn.weight)需要与你模型中对应参数的名字(如blocks.0.attn.qkv.weight)正确匹配。你需要编写一个映射函数来建立这种对应关系。 - 严格加载:使用
model.load_state_dict(state_dict, strict=True/False)加载。strict=False可以允许部分不匹配,但最好还是做到完全匹配。
这个过程让你深刻理解,模型架构就像一副骨架,而预训练权重则是附着其上的肌肉和神经。不同的骨架(架构)无法直接使用另一副骨架的肌肉。
6.2 引入LitGPT:站在巨人的肩膀上
手动处理权重加载、训练循环、分布式训练等非常繁琐。LitGPT将这些工程细节封装成了简洁易用的命令行工具和API。
LitGPT的核心优势:
- 统一的模型接口:通过一个简单的命令,如
litgpt download --repo_id meta-llama/Llama-2-7b-hf,就能下载并准备好模型。 - 简洁的微调脚本:LitGPT提供了清晰的Python脚本,用于进行全参数微调、LoRA等高效微调。
- 开箱即用的基础设施:它内置了对于FSDP(完全分片数据并行)、混合精度训练、梯度累积等高级训练技术的支持,让你可以更专注于任务和数据,而不是工程调试。
使用LitGPT进行指令微调的典型流程:
- 准备数据集:将你的指令-回答对整理成JSONL格式,每条数据包含
“instruction”、“input”(可选)、“output”字段。 - 下载基础模型:使用
litgpt download命令。 - 运行微调脚本:使用类似
litgpt finetune lora --data_dir your_data/ --checkpoint_dir checkpoints/llama2-7b/的命令,启动LoRA微调。 - 合并与推理:微调完成后,可以将LoRA适配器权重与基础模型合并,然后使用
litgpt generate命令进行对话测试。
从手写代码到使用LitGPT,你会感受到生产力质的飞跃。这正是一个AI工程师的标准工作流:深入理解原理,然后熟练运用高效工具来解决实际问题。
7. 微调策略详解:让通用模型为你所用
预训练模型拥有广博的知识,但要让其遵循指令、适应特定风格或领域,就需要微调。工作坊重点介绍了指令微调。
7.1 指令微调数据准备
指令微调的数据质量至关重要。一个糟糕的数据集会让模型学会错误的模式。
高质量指令数据的特征:
- 多样性:涵盖多种任务类型,如问答、摘要、创作、代码生成、推理等。
- 清晰的格式:指令明确,输出质量高。例如,
“写一首关于春天的诗”对应一首优美的诗歌。 - 多轮对话:可以包含多轮对话数据,训练模型的上下文理解能力。
数据格式示例(JSONL):
{"instruction": "将以下英文翻译成中文。", "input": "Hello, world!", "output": "你好,世界!"} {"instruction": "用一句话总结下面这段话。", "input": "Transformer架构是当前大语言模型的基础...", "output": "Transformer是一种基于自注意力机制的神经网络架构,已成为大语言模型的核心。"}实操心得:对于个人项目,可以从Alpaca、ShareGPT等开源指令数据集中筛选和清洗。甚至可以使用GPT-4等强大模型,为自己的领域数据生成高质量的指令-输出对,构建专属数据集。数据量不一定需要极大,几千条高质量数据就能对7B规模的模型产生显著影响。
7.2 高效微调技术:LoRA与QLoRA
全参数微调需要更新模型所有参数,计算和存储成本极高。LoRA(Low-Rank Adaptation)是一种参数高效的微调方法。
LoRA原理简述: 它冻结预训练模型的权重,只在Transformer层的注意力模块中,注入可训练的“低秩适配器”。具体来说,对于原有的权重矩阵W,LoRA引入两个小的矩阵A和B,使得前向传播变为h = Wx + BAx。其中,A和B的秩r很小(如8、16),因此可训练参数数量剧减(可能只有原模型的0.1%)。
使用LoRA的优势:
- 显存占用低:因为大部分参数被冻结,只需要存储和优化适配器参数及对应的梯度。
- 训练速度快:参数少,优化步骤自然更快。
- 产出模型小:只需保存小小的适配器权重(几MB到几十MB),而非整个模型(几GB到几十GB)。
- 可切换任务:同一个基础模型可以搭配多个不同的LoRA适配器,快速切换不同任务。
QLoRA则在LoRA的基础上更进一步,将基础模型的权重量化为4位精度(如NF4),并在训练时以一种特殊的方式维持高精度梯度。这使得在单张消费级显卡(如24GB显存的RTX 4090)上微调70B参数的大模型成为可能。
在LitGPT中,使用LoRA进行微调非常简单,几乎只需在命令行中指定--lora参数即可。这让你能将有限的算力集中在最重要的参数更新上。
8. 常见问题、调试技巧与避坑指南
在复现这个工作坊或进行类似项目时,你几乎一定会遇到下面这些问题。这里记录了我的排查思路和解决方案。
8.1 训练过程不稳定,损失值出现NaN
这是新手最常见的问题之一。
- 检查梯度:在训练循环中添加梯度范数打印。如果梯度范数突然变得极大(如超过100),很可能导致数值溢出。
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) print(f"Gradient norm: {total_norm}") - 启用梯度裁剪:如上所示,设置一个合理的
max_norm(如1.0或5.0)。 - 检查学习率:过高的学习率是罪魁祸首。尝试降低学习率(例如从3e-4降到1e-4),或启用学习率预热。
- 检查数据:确保输入数据中没有异常值(如非常大的索引号,超出了词表范围)。确保Tokenizer正常工作。
- 使用混合精度训练:
torch.cuda.amp能自动处理数值精度,有时能缓解不稳定问题。
8.2 模型不收敛,损失值居高不下
- 模型太小或数据太复杂:确认你的模型容量(参数量)是否足以拟合任务。对于教学项目,如果数据是复杂的长文本,而模型只有几万参数,那可能确实学不会。可以尝试简化数据(如使用儿童故事)或稍微增大模型(增加
hidden_size或num_layers)。 - 初始化问题:检查模型参数初始化。错误的初始化可能导致信号在深层网络中消失或爆炸。确保使用了适合Transformer的初始化方法(如PyTorch默认的初始化对于Embedding层可能偏大)。
- 损失函数计算:确认交叉熵损失函数的输入
logits和target的形状是否正确。一个常见的错误是target没有正确地从[batch, seq_len]reshape为[batch*seq_len]。 - 验证数据:确保你的训练数据是有效的。可以打印几个批次的数据,用Tokenizer解码回去,看看是不是正常的句子。
8.3 生成文本重复或没有意义
- 采样温度:尝试调整生成时的温度(
temperature)。temperature=0.0等价于贪婪解码,容易导致重复。将其设为0.7-1.0之间,并尝试结合Top-p采样(如top_p=0.9)。 - 模型训练不足:损失值还在下降吗?如果模型只训练了几个epoch,它可能只学会了最简单的字符组合。继续训练,并观察验证集损失是否持续下降。
- 检查因果掩码:确保在训练和推理时,注意力机制中的因果掩码(Causal Mask)被正确应用。如果掩码失效,模型在训练时就能“偷看”到未来的答案,导致它无法学会真正的自回归生成。
8.4 显存不足(OOM)
- 减小批次大小:这是最直接有效的方法。
- 减小序列长度:上下文窗口
block_size是显存消耗的大头(因为注意力矩阵是seq_len的平方)。尝试将其减半。 - 使用梯度累积:如果单卡批次大小只能设为1,可以通过梯度累积来模拟更大的批次。例如,每4个前向传播步骤累积一次梯度,再执行一次参数更新,这等价于批次大小为4,但显存占用仅为批次大小为1的水平。
- 启用梯度检查点:对于非常大的模型,可以使用
torch.utils.checkpoint来以计算时间换取显存空间。它会只保留部分中间变量,在反向传播时重新计算。 - 使用LitGPT/FSDP:对于真正的多卡训练,使用LitGPT内置的FSDP支持,可以将模型参数、梯度和优化器状态分片到多张卡上。
8.5 加载预训练权重时报错
- 键名不匹配:仔细对比预训练权重状态字典的键和你模型状态字典的键。编写一个详细的映射字典。使用
print(model.state_dict().keys())和print(pretrained_dict.keys())来辅助排查。 - 形状不匹配:这是最关键的。确保每一层对应的权重张量形状完全一致。例如,你的
token_embed.weight形状是[vocab_size, hidden_size],那么预训练权重的对应项也必须是这个形状。如果词表大小不同,可能需要截取或特殊处理。 - 数据类型和设备:确保将权重加载到正确的设备(CPU/GPU)上,并注意数据类型(float16/float32)。
走通从零构建LLM的整个流程,是一次无与伦比的学习体验。它剥开了LLM神秘的外衣,让你看到其下精妙而优雅的工程结构。虽然你亲手训练的“小模型”远不及GPT-4强大,但这份对底层原理的深刻理解,将成为你后续使用、调优乃至创新大模型技术的坚实基石。当你再看到一篇关于新架构的论文,或需要为特定任务定制一个模型时,你会清楚地知道该从何处入手,如何评估,以及可能面临哪些挑战。这就是动手实践的价值。
