当前位置: 首页 > news >正文

在transformer架构里面前向传播和反向传播-gradient checkpoint

目录
    • 1. Transformer 的基本结构回顾
    • 2. 前向传播在 Transformer 中的执行
    • 3. 反向传播在 Transformer 中的执行
    • 4. 哪些层执行前向/反向传播
  • 图例
      • Transformer Block 文本示意
      • 开启 Gradient Checkpoint 后(节省显存)

好的,我们把问题缩小到 Transformer 架构 里,来看前向传播和反向传播是如何执行的。为了清晰,我会分层次讲,从整体到细节。


1. Transformer 的基本结构回顾

一个标准 Transformer 的 Encoder 或 Decoder Block,通常包含以下几层(按顺序):

  1. Multi-Head Attention (MHA)

    • 输入:上一层输出(或 embedding)
    • 输出:经过注意力加权的表示
  2. Add & LayerNorm (残差 + 归一化)

  3. Feed Forward Network (FFN)

    • 两层线性层 + 激活(通常是 GELU 或 ReLU)
  4. Add & LayerNorm (残差 + 归一化)

所以一个 Block 可以画作:

Input → MHA → Add&Norm → FFN → Add&Norm → Output

2. 前向传播在 Transformer 中的执行

原则:从输入到输出,一层层执行前向传播。

  • Embedding 层

    • 输入 token → embedding 向量 + positional encoding
  • 每个 Transformer Block

    1. Multi-Head Attention

      • 计算 Query/Key/Value → 注意力权重 → 输出
      • 保存输出(激活)用于反向传播
    2. Add & LayerNorm

      • 残差连接 + 归一化
      • 保存归一化输出
    3. FFN

      • 线性 → 激活 → 线性
      • 保存 FFN 输出
    4. Add & LayerNorm

      • 最终输出作为下一层输入
  • 最后输出层(比如 LM head)

    • 输出 logits → 预测概率

总结前向传播就是从 Embedding 到最后输出,每层都计算并存激活。


3. 反向传播在 Transformer 中的执行

原则:从输出层开始,沿着网络倒着计算梯度。

  • 输出层

    • 计算 loss 对 logits 的梯度
  • 最后一个 Block

    1. Add&Norm → 梯度向 FFN 和残差路径传播
    2. FFN → 梯度回传到线性层权重
    3. Add&Norm → 梯度回传到 MHA 输出
    4. MHA → 梯度回传到 Query/Key/Value 权重
  • 前一层 Block

    • 重复相同流程
  • Embedding 层

    • 计算 embedding 权重的梯度

注意点

  • 梯度计算需要前向的激活

    • FFN 层的输入
    • MHA 的 Q/K/V
    • LayerNorm 的输入
  • 如果用了 gradient checkpoint

    • 不保存某些中间激活
    • 反向传播时需要重新计算前向来得到激活,再算梯度

4. 哪些层执行前向/反向传播

层类型 前向传播 反向传播
Embedding X → embedding 梯度回传到 embedding 权重
MHA (Attention) Q/K/V → attention → 输出 梯度回传到 Q/K/V 权重,算梯度时可能需要前向激活
FFN Linear → Activation → Linear 梯度回传到两层线性层权重,Activation 需要前向值
LayerNorm 归一化 梯度回传到输入和归一化参数
Output / LM head Linear → LogSoftmax 梯度回传到最后线性层

总结一句话

Transformer 每一层的前向传播先算输出,反向传播在梯度链上倒着执行,每层都用前向激活来算梯度。


如果你想,我可以画一张Transformer Block 前向和反向传播的示意图,标出激活保存和 gradient checkpoint 的位置,让你直观理解“什么时候重新算一遍”。

图例

好的,我用文本示意图来展示一个 Transformer Block 的前向传播、反向传播,以及 gradient checkpoint 机制下的“重新算一遍”。


Transformer Block 文本示意

前向传播 (Forward Pass)
------------------------
Input│▼
[Multi-Head Attention]  ← 前向输出保存(activation)│▼
[Add & LayerNorm]       ← 前向输出保存(activation)│▼
[Feed Forward Network]│   ├─ Linear1│   ├─ Activation│   └─ Linear2           ← 前向输出保存▼
[Add & LayerNorm]       ← 前向输出保存│▼
Output (送入下一个 Block)

反向传播 (Backward Pass)
------------------------
Output gradient│▼
[Add & LayerNorm]       ← 使用前向激活计算梯度│▼
[Feed Forward Network]│   ├─ Linear2 gradient│   ├─ Activation gradient│   └─ Linear1 gradient▼
[Add & LayerNorm]│▼
[Multi-Head Attention]│   ├─ Output gradient│   └─ Q/K/V gradient▼
Input gradient

开启 Gradient Checkpoint 后(节省显存)

  • 假设只在 [Add&LayerNorm] 层存 checkpoint,FFN 和 MHA 不存中间激活
  • 反向传播时:
Input gradient│▼
[Multi-Head Attention]  ← 没存激活 → 重新执行前向计算得到中间激活,再算梯度│▼
[Add & LayerNorm]       ← checkpoint 激活存在,直接算梯度│▼
[Feed Forward Network]  ← 没存激活 → 重新执行前向计算得到激活,再算梯度│▼
[Add & LayerNorm]       ← checkpoint 激活存在│▼
Output gradient

关键点

  1. 没存激活的层 → 反向传播时需要重新算一次前向传播
  2. checkpoint 层 → 激活直接可用,不需要重新算
  3. 这样就用计算时间换显存空间

http://www.jsqmd.com/news/200914/

相关文章:

  • AI训练和推理到底需要什么样的配置?我的一次昂贵教训
  • 2025年十大高风险漏洞及其在实际攻击中的应用
  • 【Week3_Day11】【软件测试学习记录与反思】【TPshop项目的linux部署、整理思维导图、归纳遇到的问题、记录反思改进】
  • LLM 量化技术概述及 AWQ 和 GPTQ 介绍
  • 生成1.8万年气候数据,英伟达等提出长距离蒸馏,仅需单步计算实现长期天气预报
  • 【拯救HMI】工业HMI新手学习路径:30天系统化入门与实操蓝图
  • Web 常用的图片格式选择
  • 百度百舸面向百度天池超节点的大模型推理引擎优化,持续降低昆仑芯 XPU 的 token 成本
  • 【拯救HMI】HMI信息架构设计:四层金字塔模型——构建符合认知负荷的高效界面
  • 九氚汇领衔:2026年五大主流CRM系统最新排名深度解析与选型指南
  • 一位教师的使用分享:我是如何借助AI工具高效完成年终总结PPT的
  • 1.44 NoteBookLM使用指南:Google的AI笔记工具,让文档变成智能助手
  • 2026爆火AI论文神器限时公开:9款一键生成覆盖毕业期刊职称
  • 1.45 Embedding模型选择指南:文本向量化,如何选择最适合的模型
  • 口碑好的煤矿水仓清淤供应商
  • 【GNSS信号处理】多系统GNSS实时PPP(精密单点定位)解算MATLAB代码,支持 GPS、GLONASS、Galileo、北斗系统,集成了 SSR 轨道钟差、电离层 对流层改正、卫星码偏差
  • 便秘救星!可溶性VS不溶性膳食纤维,你吃对了吗?
  • 煤矿水仓清淤哪个好
  • 【路径规划】基于目标偏置高斯分布RRT算法实现机器人路径规划附matlab代码
  • 使用VIRobotics VI Generator轻松在LabVIEW中生成数学曲线
  • HR搭建薪酬体系,该优先公平还是激励?
  • 0x3f第22天复习 (8:50-10:10)(16:30-17.06)
  • react组件外的变量是共用的
  • 永久免费HTTPS证书申请教学
  • 从“零”开始,推演出CPU核心部件的诞生过程。
  • protues仿真软件操作的那篇及输出内容
  • CPU中的逻辑单元、存储单元的介绍
  • 通过按钮改变引脚的电平的状态并输出虚拟终端
  • markdown文件在vue网页上正确显示的方式(marked + DOMPurify)
  • 鸿蒙生态新篇章:从手机到电脑的全景升级