当前位置: 首页 > news >正文

使用torch.compile与梯度累积加速模型训练

训练一个具有深度Transformer架构的语言模型是耗时的。然而,有些技巧可以用来加速训练。在本文中,你将学习到:

  • 使用 torch.compile() 加速模型
  • 使用梯度累积来训练具有更大有效批次大小的模型

让我们开始吧!

概述

本文分为两个部分:

  • 使用 torch.compile()
  • 梯度累积

使用 torch.compile

当你在PyTorch中编写并运行模型代码时,它是在eager模式下执行的。这意味着代码是一行一行执行的,结果存储在内存中。这是Python的原生方式,因为它是一种解释型语言。你知道这一点是因为当代码出现错误时,只有运行到该行时才会看到错误提示。

在eager模式下运行模型速度较慢。从PyTorch 2.0开始,你可以使用torch.compile()来编译模型以提高性能。这会生成一个经过优化的新模型对象。它不是你用nn.Module创建的原始模型对象,但它与原始模型共享相同的张量。你可以像往常一样使用这个编译后的模型进行前向传播、反向传播和优化器更新。

将模型构建并编译成计算图正是TensorFlow 1.0的设计思路。这使得调试更加困难,因为你执行的模型无法与你编写的代码逐行对应。因此,在运行试验并确认模型没有错误之前,你不应该编译模型。

并非所有模型都可以编译。但是,如果你的模型支持编译,你将立即受益于速度提升。要编译一个模型,你只需要在准备使用模型之前替换模型对象:

... model = LlamaForPretraining(model_config).to(device) model.load_state_dict(checkpoint) model = torch.compile(model) ...

不要在编译后加载模型权重。这是因为编译后的模型是一个与原始模型共享权重的对象。在编译过程中,构建的计算图引用了原始模型的权重张量。如果你在编译后加载权重,模型可能无法按预期工作。

同样,要保存编译后的模型,你应该引用原始模型的状态字典,如下所示:

torch.save(getattr(model, "_orig_mod", model).state_dict(), "model.pth")

可以通过model._orig_mod访问编译模型中的原始模型。在上面的代码中,我们使用getattr(model, "_orig_mod", model)来获取原始模型(如果存在),或者如果不存在则使用模型本身。这行代码对编译模型和原始模型都适用。

梯度累积

当你训练一个模型时,你在反向传播上花费的时间可能是前向传播的两到三倍。这是因为反向传播计算强度更大,并且占用更多内存。

一个简单的加速训练技巧是减少反向传播的次数。这可以通过增加批次大小来实现:对于相同数量的数据样本,更大的批次大小意味着要处理的批次更少。

然而,更大的批次大小需要更多内存。在内存受限的环境中,你可以通过运行多次前向传播并累积梯度来模拟更大的批次大小。这被称为梯度累积

用代码来解释这个想法更容易:

.. accumulate_steps = 4 for epoch in range(num_epochs): optimizer.zero_grad() for i, batch in enumerate(dataloader): # 获取批次数据 input_ids, target_ids = batch # 创建注意力掩码:因果掩码 + 填充掩码 attn_mask = create_causal_mask(input_ids.shape[1], device) + \ create_padding_mask(input_ids, PAD_TOKEN_ID, device) # 从模型提取输出 logits = model(input_ids, attn_mask) # 计算损失:logits与目标之间的交叉熵,忽略填充标记 loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1)) loss = loss / accumulate_steps # 运行反向传播,但每`accumulate_steps`步才更新一次 loss.backward() if (i + 1) % accumulate_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() scheduler.step()

上面的训练循环摘自上一篇关于在本地GPU上训练Llama模型的文章。

通常,当你运行一次前向传播时,你会计算损失。然后调用loss.backward()通过模型参数反向传播损失梯度。在PyTorch中,backward()方法是累积的,这意味着梯度是相加的。因此,你需要在运行反向传播之前显式调用optimizer.zero_grad()来清除梯度。

在上面的代码中,你故意不在每次迭代中都调用optimizer.zero_grad()。相反,你对损失(除以accumulate_steps)运行反向传播。这样,梯度被缩小但在accumulate_steps次迭代中累积。每经过accumulate_steps次迭代,你才运行优化器来调整模型参数。

这种方法产生的结果与使用更大批次大小获得的结果相当。然而,由于你运行的优化器更新次数更少,学习率调度器应相应调整。这意味着你需要用不同的步数来初始化调度器:

... num_training_steps = (len(dataloader) // accumulate_steps) * num_epochs cosine_scheduler = lr_scheduler.CosineAnnealingLR( optimizer, T_max=num_training_steps - num_warmup_steps, eta_min=0 )

进一步阅读

以下是一些你可能感兴趣的资料:

  • torch.compile 文档
  • PyTorch 文档中的自动混合精度示例

总结

在本文中,你了解到使用torch.compile()可以通过编译计算图来帮助你加速模型。你还了解到,梯度累积是一种通过累积多个小批次的梯度来训练更大有效批次大小的技术。由于这种方式减少了优化器更新次数,你可以节省反向传播和参数更新的时间。
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)

http://www.jsqmd.com/news/298692/

相关文章:

  • 王同學文章轉載公告
  • 做社区垃圾分类指导工具,输入垃圾名称,自动识别所属分类(可回收/厨余/有害/其他),标注时间,投放点,附详细分类说明。
  • DBShadow.net之化繁为简
  • 无锡、杭州升降平台厂商价格比较,哪个更合适?
  • 聊聊同质透心pvc地板精品定制,新凯琳厂家优势在哪?
  • 南京整装装修专业公司哪家好,红牛装饰实力揭秘
  • 计算机毕业设计springboot房屋租赁管理系统 基于SpringBoot的在线房屋出租与求租撮合平台 SpringBoot+Vue智慧住房租赁综合服务平台
  • 计算机毕业设计springboot房屋租赁管理系统 基于SpringBoot的在线房源租售一体化运营平台 SpringBoot+Thymeleaf智慧住房租赁合约管理系统
  • 工程防火材料选型指南:深度解读消防验收标准下的廊坊大浩防火涂料方案,防火涂料/电缆防火涂料,防火涂料供应商排行榜单
  • 2026火锅测评:哪些网红品牌值得一试?重庆火锅/美食/社区火锅/牛肉火锅/老火锅/附近火锅/火锅,火锅品牌推荐排行
  • 如何使用 GitHub Actions + image-syncer 实现 Docker Hub 到 Azure ACR 的自动化镜像同步
  • 计算机毕业设计springboot防诈知识在线学习系统 基于SpringBoot的反诈骗科普互动学习平台 SpringBoot+Vue智慧防诈在线教育系统
  • Java毕设项目:基于springboot的私厨服务平台的设计与实现(源码+文档,讲解、调试运行,定制等)
  • 计算机毕业设计springboot房屋租赁系统 基于SpringBoot的在线房屋出租与求租撮合平台 SpringBoot+Vue智慧住房租赁综合服务平台
  • 计算机毕业设计springboot房车旅途 基于SpringBoot的房车租赁与售卖一体化平台 SpringBoot+Vue智慧房车出行服务系统
  • 计算机毕业设计springboot房源出租信息系统 基于SpringBoot的在线房屋租售一体化平台 SpringBoot+Vue智慧房源租赁撮合系统
  • Java计算机毕设之基于java+springboot的花店鲜花销售管理系统基于springboot的鲜花销售管理系统的设计与实现(完整前后端代码+说明文档+LW,调试定制等)
  • Java计算机毕设之基于Springboot的社区老年人健康管理系统基于springboot的社区独居老人健康管理系统(完整前后端代码+说明文档+LW,调试定制等)
  • 2026年靠谱的微压富氧舱解决方案大揭秘,原力氧FVIP值得关注
  • 【毕业设计】基于springboot的私厨服务平台的设计与实现(源码+文档+远程调试,全bao定制等)
  • 【课程设计/毕业设计】基于springboot+vue的社区独居老人健康管理系统基于springboot的社区空巢老人健康管理系统 社区独居老人健康管理系统【附源码、数据库、万字文档】
  • 肯能机械口碑如何,江西包装机厂家排名有它吗?
  • 【毕业设计】基于springboot的社区独居老人健康管理系统(源码+文档+远程调试,全bao定制等)
  • 说说升降平台供应企业哪家专业,无锡固佳工业设备上榜
  • 润滑油泵选型怎么选,专业支招让你不再迷茫
  • 用一套开源AI视觉系统,提升商场促销活动效果
  • 深入解析:MyBatis框架 - 延迟加载+一/二级缓存
  • 深聊氨基酸洗发产品,好用的品牌都在这了
  • auto类型和范围for循环
  • 导师推荐2026最新!9款AI论文工具测评:专科生毕业论文必备