别再死记硬背公式了!用PyTorch代码实战FGM、PGD和FreeLB,手把手教你提升NLP模型鲁棒性
对抗训练实战:用PyTorch代码解析FGM、PGD与FreeLB的核心实现差异
当你第一次在论文里看到对抗训练的Min-Max公式时,是否感觉像在读天书?那些关于"内层最大化扰动,外层最小化损失"的理论描述,往往让工程师们陷入"懂了但不会写"的困境。今天我们不谈数学推导,直接深入代码层面,通过对比FGM、PGD和FreeLB三种经典算法的PyTorch实现,带你掌握对抗训练的真实落地技巧。
1. 对抗训练的本质:代码视角下的双重梯度更新
对抗训练的核心思想其实可以用两行伪代码概括:
# 内层:寻找使损失最大的扰动 perturbation = gradient_ascent(original_loss) # 外层:用扰动样本更新模型 model_update(adversarial_loss)但在实际实现中,不同算法对这两个步骤的处理方式差异显著。以最常见的NLP任务为例,对抗扰动通常施加在embedding层,因为:
- 文本离散性导致无法直接在词ID上做扰动
- embedding空间连续可微,适合梯度计算
- 下游任务性能对embedding变化敏感
关键实现差异点在于:
- 扰动计算方式(一步到位vs迭代优化)
- 梯度累积策略(覆盖vs累加)
- 参数恢复时机(权重vs梯度)
2. FGM实现解析:一步到位的对抗样本生成
Fast Gradient Method(FGM)是最轻量级的对抗训练方法,其核心思想是在梯度方向上一步到位地添加扰动。以下是需要特别注意的实现细节:
class FGM: def attack(self, epsilon=1., emb_name='word_embeddings'): for name, param in self.model.named_parameters(): if emb_name in name and param.requires_grad: self.backup[name] = param.data.clone() # 备份原始embedding norm = torch.norm(param.grad) # 计算梯度范数 if norm != 0: r_at = epsilon * param.grad / norm # 归一化扰动 param.data.add_(r_at) # 施加扰动 def restore(self, emb_name='word_embeddings'): for name, param in self.model.named_parameters(): if emb_name in name and param.requires_grad: param.data = self.backup[name] # 恢复原始embedding self.backup = {}典型使用陷阱:
- 忘记在attack前执行
loss.backward()会导致梯度为None - 错误指定emb_name导致扰动未应用到目标层
- 在restore之前调用optimizer.step()会造成参数污染
提示:FGM的epsilon参数需要精细调校,一般从0.05开始尝试,过大可能导致模型性能下降
3. PGD实现精讲:迭代式对抗攻击的工程细节
Projected Gradient Descent(PGD)通过多步小扰动来提升对抗样本质量,其实现复杂度显著高于FGM。关键实现组件包括:
class PGD: def attack(self, is_first_attack=False): for name, param in self.model.named_parameters(): if self.emb_name in name and param.requires_grad: if is_first_attack: self.emb_backup[name] = param.data.clone() # 首次备份 norm = torch.norm(param.grad) if norm != 0: r_at = self.alpha * param.grad / norm # 计算单步扰动 param.data.add_(r_at) param.data = self.project(name, param.data) # 投影到约束空间 def project(self, param_name, param_data): # 将扰动限制在ε-ball内 r = param_data - self.emb_backup[param_name] if torch.norm(r) > self.epsilon: r = self.epsilon * r / torch.norm(r) return self.emb_backup[param_name] + rPGD训练循环中的关键时序:
- 正常前向传播计算原始loss
- 备份原始梯度(
backup_grad) - 进行K步对抗迭代:
- 每步计算当前扰动并更新embedding
- 非最后一步时清零梯度
- 最后一步恢复原始梯度
- 恢复原始embedding参数
- 执行参数更新
# 典型训练循环结构 for batch in dataloader: loss = model(batch) # 原始前向 loss.backward() # 原始反向 pgd.backup_grad() # 梯度备份 for t in range(K): # K步对抗 pgd.attack(is_first_attack=(t==0)) if t != K-1: model.zero_grad() else: pgd.restore_grad() loss_adv = model(batch) loss_adv.backward() pgd.restore() # 恢复embedding optimizer.step() # 参数更新4. FreeLB的创新实现:梯度累积的对抗策略
FreeLB(Free Large-Batch)通过梯度累积实现更高效的对抗训练,其核心创新点在于:
- 在整个对抗过程中不重置梯度
- 使用累积梯度更新模型参数
- 支持动态调整扰动幅度
实现关键点解析:
def attack(self, model, inputs): embeds_init = get_embeddings(model, inputs) # 获取初始embedding delta = self.initialize_delta(embeds_init) # 扰动初始化 for astep in range(self.adv_K): delta.requires_grad_() # 启用扰动梯度 inputs['inputs_embeds'] = embeds_init + delta # 应用扰动 outputs = model(**inputs) loss = outputs[0] loss.backward() # 梯度累积 # 更新扰动 delta_grad = delta.grad.detach() if self.adv_norm_type == "l2": denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1) delta = (delta + self.adv_lr * delta_grad / denorm).detach() # 投影操作省略... return model(**inputs) # 返回最终结果参数调优经验:
| 参数 | 推荐范围 | 作用说明 |
|---|---|---|
| adv_K | 3-5 | 对抗步数,过多易导致过拟合 |
| adv_lr | 1e-2 | 扰动学习率 |
| adv_init_mag | 1e-2 | 初始扰动幅度 |
| adv_max_norm | 0.5-2.0 | 最大扰动约束 |
5. 三大算法实战对比与选型建议
在实际项目中如何选择合适的对抗训练方法?以下是从工程角度总结的对比维度:
计算效率对比:
- FGM:额外计算开销约20-30%
- PGD:K倍计算开销(通常K=3)
- FreeLB:约1.5倍于FGM的开销
实现复杂度对比:
# 代码复杂度评分(1-5分,越高越复杂) complexity = { 'FGM': 2, # 只需实现attack/restore 'PGD': 4, # 需管理梯度/参数双重备份 'FreeLB': 3 # 需处理梯度累积逻辑 }效果对比建议:
- 资源有限时首选FGM
- 追求极致效果可尝试PGD
- 大批量训练时FreeLB更高效
- 结合早停策略防止过拟合
常见坑点解决方案:
- 梯度消失:检查扰动是否过小
- 训练震荡:降低epsilon或adv_lr
- 性能下降:尝试冻结底层参数
- OOM错误:减小batch_size
# 鲁棒性测试代码片段 def test_robustness(model, test_loader, attack_method): model.eval() total = 0 correct = 0 for inputs, labels in test_loader: # 生成对抗样本 adv_inputs = attack_method.generate(inputs, labels) outputs = model(adv_inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() return correct / total在BERT-base模型上的实测数据显示,合理使用对抗训练可以使文本分类任务的对抗样本准确率提升15-20%,同时普通样本准确率也有2-3%的提升。不过要注意,对抗训练通常会延长30-50%的训练时间,需要在效果和效率之间做好权衡。
