大模型训练流程实战:从预训练到推理的完整技术解析
大模型训练流程实战:从预训练到推理的完整技术解析
导读:本文系统梳理大语言模型从预训练到推理的完整技术链路,涵盖数据工程、SFT微调、RLHF对齐、推理机制、幻觉治理等核心环节,结合实战代码与深度案例,帮助开发者建立端到端的训练认知。
第一章 核心认知:大模型的本质是什么
大模型不是写规则写出来的,而是通过"预测下一个token"把海量文本规律压进参数,推理时根据当前上下文逐token生成。
这个框架理解后,所有概念都顺了。
1.1 训练与推理的本质区别
| 阶段 | 目标 | 输入 | 输出 | 关键操作 |
|---|---|---|---|---|
| 预训练 | 学习语言规律 | 海量文本 | 下一个token预测 | 反向传播更新参数 |
| SFT | 学会指令格式 | 问答对样本 | 优质回答 | 监督微调 |
| 偏好对齐 | 学会"好回答" | 偏好排序数据 | 符合偏好的输出 | RLHF/DPO优化 |
| 推理 | 生成用户想要的回答 | 用户prompt | 逐token生成文本 | 前向传播+采样 |
1.2 参数到底是什么
模型参数不是"第9527个参数存着Redis为什么快"这种精确知识,而是高维空间里的规律——大量参数共同形成一种分布式表示。
# 参数本质:神经连接强度矩阵importtorchimporttorch.nnasnn# 一个简单的Transformer层参数结构classSimpleTransformer(nn.Module):def__init__(self,d_model=768,n_heads=12):super().__init__()# 参数是权重矩阵,不是"知识条目"self.W_q=nn.Linear(d_model,d_model)# 查询投影self.W_k=nn.Linear(d_model,d_model)# 键投影self.W_v=nn.Linear(d_model,d_model)# 值投影self.W_o=nn.Linear(d_model,d_model)# 输出投影defforward(self,x):q,k,v=self.W_q(x),self.W_k(x),self.W_v(x)# 注意力机制:参数共同协作捕捉序列模式scores=torch.matmul(q,k.transpose(-2,-1))/(d_model**0.5)attn=torch.softmax(scores,dim=-1)returntorch.matmul(attn,v)关键洞察:预训练就是反复调整这些连接强度,让模型学会"什么样的上下文后面应该接什么token"。
第二章 Token化:模型看到的不是文字
2.1 Tokenization 核心原理
文本 → tokenizer → token序列 → 映射为id → 模型处理
模型不是在"字符"层面理解,而是在"token序列"上学习。tokenization策略直接影响模型效果。
2.2 主流Tokenizer对比
| Tokenizer | 代表模型 | 词表大小 | 特点 | 适用场景 |
|---|---|---|---|---|
| BPE | GPT系列 | 50K | 基于字节对编码 | 英文为主 |
| WordPiece | BERT | 30K | 基于词片段 | NLU任务 |
| SentencePiece | T5 | 32K | 无空格预处理 | 多语言 |
| Unigram | ALBERT | 30K | 概率模型 | 高效分词 |
| TikToken | GPT-4 | 100K | 基于正则+统计 | 高效推理 |
2.3 实战:使用HuggingFace Tokenizer
fromtransformersimportAutoTokenizer# 加载GPT-2的tokenizertokenizer=AutoTokenizer.from_pretrained("gpt2")# 文本编码text="大模型训练是预测下一个token的过程"encoding=tokenizer(text)print(f"原始文本:{text}")print(f"Token IDs:{encoding['input_ids']}")print(f"Token数量:{len(encoding['input_ids'])}")print(f"解码回文本:{tokenizer.decode(encoding['input_ids'])}")# 查看每个token对应的文本fori,token_idinenumerate(encoding['input_ids']):print(f" Token{i}:{token_id}-> '{tokenizer.decode([token_id])}'")2.4 Tokenization 实战技巧
# 技巧1:控制最大长度encoding=tokenizer(text,max_length=512,truncation=True,padding='max_length')# 技巧2:添加特殊tokenspecial_tokens={'bos_token':'<s>','eos_token':'</s>','unk_token':'<unk>'}tokenizer.add_special_tokens(special_tokens)# 技巧3:批量编码(推理时常用)texts=["你好","大模型很强大","训练需要大量数据"]batch_encoding=tokenizer(texts,padding=True,truncation=True,return_tensors="pt")print(f"Batch shape:{batch_encoding['input_ids'].shape}")# 技巧4:计算token数量(预估推理成本)defcount_tokens(text,model_name="gpt2"):tok=AutoTokenizer.from_pretrained(model_name)returnlen(tok.encode(text))print(f"估算token数:{count_tokens(text)}")第三章 数据工程:脏数据会把模型带歪
3.1 数据质量决定模型上限
不是越多越好,要清洗、去重、质量打分、过滤敏感内容。很多模型能力差,不是架构不行,是数据工程没做好。
3.2 数据清洗流水线
importrefromcollectionsimportCounterclassDataCleaner:"""数据清洗流水线"""def__init__(self):self.bad_patterns=[r'<script.*?>.*?</script>',# 移除脚本r'<style.*?>.*?</style>',# 移除样式r'http[s]?://\S+',# 移除URLr'[^\u4e00-\u9fff\w\s\.,!?,。!?]',# 保留中英文和标点]defclean(self,text):# 1. 移除HTML标签和URLforpatterninself.bad_patterns:text=re.sub(pattern,'',text)# 2. 移除多余空白text=re.sub(r'\s+',' ',text).strip()# 3. 移除过短内容iflen(text)<10:returnNonereturntextdefquality_score(self,text):"""简单质量打分"""score=0# 长度分score+=min(len(text)/1000,3)# 中文比例分chinese_chars=len(re.findall(r'[\u4e00-\u9fff]',text))score+=min(chinese_chars/len(text)*2,2)iftextelse0# 标点合理性punct_ratio=len(re.findall(r'[.,!?,。!?]',text))/len(text)iftextelse0score+=min(punct_ratio*5,2)returnscore