LLM 训练:从预训练到微调
1. 技术分析
1.1 LLM 训练流程
LLM 训练分为预训练和微调两个阶段:
训练流程:预训练(大规模无监督训练)→ 微调(特定任务训练)→ RLHF(人类反馈强化学习)
1.2 预训练 vs 微调
| 阶段 | 数据 | 目标 | 方法 |
|---|---|---|---|
| 预训练 | 大规模文本 | 语言建模 | 自监督学习 |
| 微调 | 任务数据 | 特定任务 | 监督学习 |
| RLHF | 人类反馈 | 对齐人类偏好 | 强化学习 |
1.3 训练策略
训练策略:预训练采用下一个 token 预测;微调采用监督微调(SFT);RLHF 采用奖励模型 + PPO。
2. 核心功能实现
2.1 预训练数据处理
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """Dataset wrapping a sequence of raw texts for causal-LM pretraining.

    Each item is tokenized to a fixed length. Labels are a copy of the
    input ids with padding positions set to -100 so that an HF-style
    loss (CrossEntropyLoss with ignore_index=-100) does not train on
    padding — the original copied input_ids verbatim and therefore
    computed loss on pad tokens.
    """

    def __init__(self, text, tokenizer, max_len=512):
        self.text = text            # sequence of raw text strings
        self.tokenizer = tokenizer  # HF-style callable tokenizer
        self.max_len = max_len      # fixed sequence length after padding

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        text = self.text[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].flatten()
        labels = input_ids.clone()
        # Fix: exclude padding from the loss. Guarded so tokenizers that
        # do not return an attention mask keep the original behavior.
        if 'attention_mask' in encoding:
            attention_mask = encoding['attention_mask'].flatten()
            labels[attention_mask == 0] = -100
        return {'input_ids': input_ids, 'labels': labels}


class DataCollator:
    """Stacks per-example tensors into a single batch dict."""

    def __call__(self, batch):
        input_ids = torch.stack([item['input_ids'] for item in batch])
        labels = torch.stack([item['labels'] for item in batch])
        return {'input_ids': input_ids, 'labels': labels}


class TextDataLoader:
    """Convenience wrapper bundling dataset, collator and DataLoader."""

    def __init__(self, texts, tokenizer, batch_size=32, max_len=512):
        self.dataset = TextDataset(texts, tokenizer, max_len)
        self.collator = DataCollator()
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            collate_fn=self.collator,
            shuffle=True,
        )

    def __iter__(self):
        return iter(self.dataloader)

    def __len__(self):
        return len(self.dataloader)
2.2 预训练训练器
class PretrainingTrainer:
    """Single-process causal-LM pretraining loop.

    Parameters
    ----------
    model : module called as ``model(input_ids, labels=...)`` and
        returning an object exposing ``.loss`` (HF style).
    optimizer : torch optimizer over the model's parameters.
    scheduler : LR scheduler stepped once per optimization step.
    loss_fn : kept for backward compatibility — the model computes its
        own loss, so this argument is currently unused.
    device : target device string, e.g. 'cuda' or 'cpu'.
    """

    def __init__(self, model, optimizer, scheduler, loss_fn, device='cuda'):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn  # NOTE(review): unused; model returns loss itself
        self.device = device

    def train_step(self, batch):
        """Run one optimization step and return the scalar loss."""
        self.model.train()
        self.optimizer.zero_grad()
        input_ids = batch['input_ids'].to(self.device)
        labels = batch['labels'].to(self.device)
        outputs = self.model(input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()

    def train_epoch(self, dataloader):
        """Return the mean loss over one full pass of the dataloader."""
        total_loss = 0.0
        for batch in dataloader:
            total_loss += self.train_step(batch)
        return total_loss / len(dataloader)


class DistributedPretrainer:
    """Multi-GPU pretraining via DistributedDataParallel.

    Assumes ``torch.distributed`` has already been initialized by the
    launcher (DDP construction fails otherwise).
    """

    def __init__(self, model, config, device='cuda'):
        # Generalized: device is configurable (was hard-coded 'cuda'
        # inside the training loop).
        self.device = device
        self.model = torch.nn.parallel.DistributedDataParallel(model)
        self.config = config

    def train(self, dataloader):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['lr'])
        # T_max is expressed in epochs, so the scheduler must step once
        # per epoch. The original stepped it every batch, which collapsed
        # the entire cosine schedule inside the first epoch.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.config['epochs'])
        for epoch in range(self.config['epochs']):
            for batch in dataloader:
                optimizer.zero_grad()
                input_ids = batch['input_ids'].to(self.device)
                labels = batch['labels'].to(self.device)
                outputs = self.model(input_ids, labels=labels)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
            scheduler.step()
2.3 微调与 RLHF
class SFTTrainer:
    """Supervised fine-tuning on (instruction, response) pairs.

    Each pair is formatted into a single prompt and trained with the
    model's own LM loss (labels == input_ids, HF style).
    """

    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

    def train(self, instruction_response_pairs):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['lr'])
        for epoch in range(self.config['epochs']):
            for instruction, response in instruction_response_pairs:
                optimizer.zero_grad()
                prompt = f"Instruction: {instruction}\nResponse: {response}"
                encoding = self.tokenizer(prompt, return_tensors='pt')
                outputs = self.model(**encoding, labels=encoding['input_ids'])
                loss = outputs.loss
                loss.backward()
                optimizer.step()


class RewardModel(nn.Module):
    """Scalar reward head on top of a base model's last hidden state."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.reward_head = nn.Linear(base_model.config.hidden_size, 1)

    def forward(self, input_ids):
        outputs = self.base_model(input_ids)
        # Score the sequence from the final token's hidden state.
        hidden_states = outputs.last_hidden_state[:, -1, :]
        return self.reward_head(hidden_states)


class PPO_trainer:
    """Minimal policy-gradient step against a reward model.

    Fixes over the original:
    - ``tokenizer`` is now a constructor argument; the original read
      ``self.tokenizer`` without ever assigning it (AttributeError).
    - the reward is detached so gradients do not flow into the reward
      model during the policy update.
    - the loss is reduced to a scalar; ``backward()`` on the original
      non-scalar ``-reward * log_probs`` would raise.
    - log-probs are computed for the *next* token (logits at position t
      score token t+1); the original gathered token t against the
      logits at the same position.
    - uses ``torch.log_softmax``; the original referenced an undefined
      name ``F``.
    """

    def __init__(self, model, reward_model, config, tokenizer=None):
        self.model = model
        self.reward_model = reward_model
        self.config = config
        self.tokenizer = tokenizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config['lr'])

    def train_step(self, prompt, response):
        """One REINFORCE-style update on a single (prompt, response)."""
        self.optimizer.zero_grad()
        prompt_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        response_ids = self.tokenizer.encode(response, return_tensors='pt')
        input_ids = torch.cat([prompt_ids, response_ids], dim=1)
        # Detach: the reward model is a fixed critic here, not trained.
        reward = self.reward_model(input_ids).detach()
        log_probs = self._compute_log_probs(input_ids)
        loss = -(reward * log_probs).mean()
        loss.backward()
        self.optimizer.step()

    def _compute_log_probs(self, input_ids):
        """Sum of per-token log-probabilities of the sequence."""
        outputs = self.model(input_ids)
        logits = outputs.logits
        # Logits at position t predict token t+1, so shift both sides.
        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
        targets = input_ids[:, 1:]
        token_log_probs = log_probs.gather(2, targets.unsqueeze(2)).squeeze(2)
        return token_log_probs.sum(dim=1)
3. 性能对比
3.1 训练阶段对比
| 阶段 | 数据量 | 计算量 | 目标 |
|---|---|---|---|
| 预训练 | 大规模 | 极高 | 通用能力 |
| SFT | 中等 | 高 | 任务能力 |
| RLHF | 小规模 | 中 | 对齐 |
3.2 训练效率对比
| 策略 | 样本效率 | 计算效率 | 效果 |
|---|---|---|---|
| 预训练 | 低 | 高 | 通用 |
| SFT | 高 | 中 | 任务 |
| RLHF | 中 | 低 | 对齐 |
3.3 模型大小影响
| 参数规模 | 预训练时间 | 微调时间 | 推理速度 |
|---|---|---|---|
| 1B | 1周 | 1天 | 1000 tokens/s |
| 10B | 1月 | 1周 | 500 tokens/s |
| 100B | 6月 | 1月 | 100 tokens/s |
4. 最佳实践
4.1 训练流程
def build_training_pipeline(config):
    """Construct the trainer for the configured training stage.

    Raises ValueError for an unknown stage (the original silently
    returned None, deferring the failure to the first attribute access).

    NOTE(review): the constructors below are called with ``config``
    only, but their definitions elsewhere in this file take
    ``(model, ..., config)`` — confirm the intended factory signatures.
    """
    stage = config['stage']
    if stage == 'pretrain':
        return PretrainingTrainer(config)
    if stage == 'sft':
        return SFTTrainer(config)
    if stage == 'rlhf':
        return PPO_trainer(config)
    raise ValueError(f"unknown training stage: {stage!r}")


class LLMTrainingWorkflow:
    """Dispatches the configured training stage to its runner."""

    def __init__(self, config):
        self.config = config

    def run(self):
        stage = self.config['stage']
        if stage == 'pretrain':
            self._run_pretraining()
        elif stage == 'sft':
            self._run_sft()
        elif stage == 'rlhf':
            self._run_rlhf()
        else:
            # Fail loudly instead of silently doing nothing.
            raise ValueError(f"unknown training stage: {stage!r}")

    def _run_pretraining(self):
        # NOTE(review): _initialize_model/_create_dataloader are not
        # defined in this file — presumably supplied by a subclass;
        # confirm before use.
        model = self._initialize_model()
        dataloader = self._create_dataloader()
        trainer = PretrainingTrainer(model, self.config)
        trainer.train(dataloader)

    def _run_sft(self):
        # NOTE(review): data-loading helpers are likewise undefined here.
        model = self._load_pretrained_model()
        data = self._load_sft_data()
        trainer = SFTTrainer(model, data, self.config)
        trainer.train()

    def _run_rlhf(self):
        model = self._load_sft_model()
        reward_model = self._train_reward_model()
        trainer = PPO_trainer(model, reward_model, self.config)
        trainer.train()
4.2 训练优化
class TrainingOptimizer:
    """Toggles common memory/throughput optimizations on a model."""

    def __init__(self, model):
        self.model = model

    def enable_mixed_precision(self):
        # Prepare an AMP gradient scaler for fp16 training; the caller's
        # training loop is responsible for actually using it.
        self.scaler = torch.cuda.amp.GradScaler()

    def enable_gradient_checkpointing(self):
        # Trade compute for memory: activations are recomputed in backward.
        self.model.gradient_checkpointing_enable()

    def enable_distributed_training(self):
        # Requires an initialized torch.distributed process group.
        self.model = torch.nn.parallel.DistributedDataParallel(self.model)

    def apply_all(self):
        # Distributed wrapping is deliberately not part of apply_all —
        # it only works after the process group is set up.
        self.enable_mixed_precision()
        self.enable_gradient_checkpointing()
5. 总结
LLM 训练是复杂的工程任务:
- 预训练:大规模无监督学习
- 微调:监督学习适配特定任务
- RLHF:强化学习对齐人类偏好
- 训练优化:混合精度、分布式训练
对比数据如下:
- RLHF 显著提升模型对齐能力
- 混合精度训练可节省 50% 内存
- 梯度检查点可节省 30-50% 内存
- 推荐使用 LoRA 进行高效微调