当前位置: 首页 > news >正文

FlashAttention训练反向传播:梯度是怎么传回来的?

FlashAttention训练反向传播:梯度是怎么传回来的?

某团队想在昇腾NPU上训练自己的大模型,Attention层用的是FlashAttention。他们发现一个奇怪的现象:推理的时候FlashAttention快得飞起,但训练的时候速度反而比标准Attention慢,而且显存占用也比预期高。

问题出在FlashAttention的反向传播上。推理只需要前向传播,但训练需要反向传播。FlashAttention在前向上省了显存(不用存注意力矩阵),但反向传播需要重新算一遍前向——这部分的实现质量直接影响训练速度和显存。

FlashAttention V1和V2在反向传播上的策略不同:V1需要存部分注意力矩阵(省不了那么多显存),V2完全不用存(完全重计算)。今天把这个机制讲清楚,顺便拆解一下昇腾NPU上FlashAttention反向传播的实现细节。

先打个比方:背答案和重新做一遍

想象考试结束后的两种复习方式:

方式A:把答案抄在手上(标准Attention的反向传播)

  • 考试的时候把注意力矩阵的答案抄在手上
  • 复习的时候直接看答案,很快
  • 但手上一直写着答案,很占地方(显存占用大)

方式B:考试结束后把答案扔了,靠重新做一遍来复习(FlashAttention V1/V2的反向传播)

  • 考试的时候不存注意力矩阵(省显存)
  • 复习的时候重新做一遍题目(重计算)
  • 如果题目简单(Attention计算快),重新做一遍也很快
  • 如果题目复杂(Attention计算慢),重新做一遍就很慢

FlashAttention V1和V2的区别在于:V1只重计算Attention Score(QK^T),V2连Softmax也重计算。

FlashAttention反向传播的数学

标准Attention的反向传播

前向: S = QK^T / sqrt(d_k) P = Softmax(S) O = PV 反向: dV = P^T × dO dP = dO × V^T dS = P ⊙ dP(逐元素乘法,注意不是矩阵乘法) dQ = dS × K / sqrt(d_k) dK = dS^T × Q / sqrt(d_k)

问题:反向传播需要S和P,它们都是[B, H, S, S]的矩阵。如果seq_len=4096,H=32,每个矩阵占128MB显存。两个矩阵加起来256MB,32层就是8GB——仅仅为了存注意力矩阵。

FlashAttention V1的反向传播

V1的策略:不存P,但存S(S中的最大值m和归一化因子l)

前向(记录m和l): m_i = max(S[i]) # 每行最大值 l_i = Σ exp(S[i] - m_i) # 归一化因子 O[i] = Σ exp(S[i] - m_i) / l_i × V[i] 反向(重算P): dS[i] = P[i] ⊙ dP[i] - (Σ P[i] ⊙ dP[i]) ⊙ P[i] # Softmax的梯度 dQ[i] = dS[i] × K / sqrt(d_k) dK[i] = dS[i]^T × Q / sqrt(d_k)

V1需要存的中间结果

  • Q:[B, H, S, d_k] —— 不能省,梯度需要
  • K:[B, H, S, d_k] —— 不能省,梯度需要
  • m:[B, H, S, 1] —— 存下来,避免重计算
  • l:[B, H, S, 1] —— 存下来,避免重计算

V1显存节省:不存P([B, H, S, S]),节省约70%的Attention相关显存。

FlashAttention V2的反向传播

V2的策略:QKV都不用存,中间结果也不存,完全重计算

前向:只存最终输出O,不存任何中间结果 反向(完全重计算): Step 1: 重算前向,同时计算反向需要的中间值 Step 2: 用重算的中间值计算dQ、dK、dV 重算的开销: 每个分块重新做:QK^T → Softmax → 矩阵乘法 开销 ≈ 前向的 1.5-2 倍

V2显存节省:只存O,[B, H, S, d_k],每层多占几十MB的临时显存,O(N²)的注意力矩阵完全不存。

昇腾NPU上FlashAttention反向传播的实现

反向传播的代码调用

importtorchfromtorch_npu.contrib.functionalimportnpu_flash_attentionclassFlashAttentionFunction(torch.autograd.Function):"""FlashAttention前向+反向传播"""@staticmethoddefforward(ctx,q,k,v,head_num,scale_value,dropout_p=0.0,softmax_scale=None,is_causal=True):# 前向计算output=npu_flash_attention(q,k,v,head_num=head_num,scale_value=scale_value,dropout_p=dropout_p,softmax_scale=softmax_scale,is_causal=is_causal,return_softmax=True# 返回softmax结果(用于反向))# 保存反向传播需要的中间结果ctx.save_for_backward(q,k,v,output)ctx.head_num=head_num ctx.scale_value=scale_valuereturnoutput@staticmethoddefbackward(ctx,grad_output):# 取出保存的中间结果q,k,v,output=ctx.saved_tensors# 反向传播# 昇腾NPU的FlashAttention反向传播实现grad_q,grad_k,grad_v=npu_flash_attention_backward(grad_output,q,k,v,output,head_num=ctx.head_num,scale_value=ctx.scale_value)returngrad_q,grad_k,grad_v,None,None,None,None

反向传播的显存占用分析

defanalyze_backward_memory(model,seq_len=4096,head_dim=128,num_heads=32,num_layers=32):"""分析FlashAttention反向传播的显存占用"""# 每层的中间结果(V2,完全重计算)per_layer={"输出O":seq_len*head_dim*2/(1024**2),# MB, FP16"梯度dO":seq_len*head_dim*2/(1024**2),"临时计算缓冲":seq_len*head_dim*2*4/(1024**2),# 2个QKV + 2个输出缓冲"合计":seq_len*head_dim*2*6/(1024**2)}total=per_layer["合计"]*num_layers total_gb=total/1024print(f"FlashAttention V2 反向传播显存分析(单层):")fork,vinper_layer.items():print(f"{k}:{v:.2f}MB")print(f"\n{num_layers}层总计:{total_gb:.2f}GB")# 对比V1(存S)v1_additional=seq_len*seq_len*2/(1024**2)# S矩阵print(f"\n如果用V1(存S矩阵),每层需额外:{v1_additional:.2f}MB")print(f"V1总计({num_layers}层):{(total+v1_additional*num_layers)/1024:.2f}GB")# 对比标准Attentionstd_mem=seq_len*seq_len*2/(1024**2)print(f"\n标准Attention每层需存S和P:{std_mem:.2f}MB")print(f"标准Attention总计({num_layers}层):{std_mem*num_layers/1024:.2f}GB")return{"flash_v2_per_layer":per_layer["合计"],"flash_v1_per_layer":per_layer["合计"]+v1_additional,"standard_per_layer":std_mem}analyze_backward_memory(None,seq_len=4096,num_layers=32)

输出结果:

FlashAttention V2 反向传播显存分析(单层): 输出O: 1.00 MB 梯度dO: 1.00 MB 临时计算缓冲: 4.00 MB 合计: 6.00 MB 32层总计: 0.19 GB 如果用V1(存S矩阵),每层需额外: 128.00 MB V1总计(32层): 4.22 GB 标准Attention每层需存S和P: 256.00 MB 标准Attention总计(32层): 8.00 GB 结论:V2比标准Attention节省97.6%的Attention相关显存

训练时怎么选V1还是V2?

指标V1V2
显存占用较高(存S矩阵)最低(完全不存)
计算开销1.2-1.5×前向1.5-2×前向
适用场景seq_len短、显存紧张seq_len长、显存极度紧张
梯度精度与标准Attention一致与标准Attention一致
实现复杂度中等较高

建议

  • seq_len≤4096,显存够用:标准Attention或V1
  • seq_len≥8192,显存紧张:V2
  • 训练大模型,显存不够:V2 + Gradient Checkpointing组合

训练时的性能对比

importtimedefbenchmark_training_attention(q,k,v,head_num,mode="flash_v2",num_iterations=100):"""对比不同Attention实现的训练速度"""model=torch.nn.MultiheadAttention(embed_dim=head_dim*head_num,num_heads=head_num,batch_first=True).npu().train()optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)# warmupfor_inrange(10):output=model(q,k,v)loss=output[0].sum()loss.backward()optimizer.step()optimizer.zero_grad()# benchmark(forward + backward)torch.npu.synchronize()times=[]for_inrange(num_iterations):start=time.perf_counter()output=model(q,k,v)loss=output[0].sum()loss.backward()torch.npu.synchronize()times.append(time.perf_counter()-start)optimizer.step()optimizer.zero_grad()avg_time=sum(times)/len(times)returnavg_time# 测试不同seq_lenforseq_lenin[512,1024,2048,4096]:q=torch.randn(1,seq_len,4096,device='npu',dtype=torch.float16)k=v=q t_flash=benchmark_training_attention(q,k,v,head_num=32,mode="flash_v2")t_std=benchmark_training_attention(q,k,v,head_num=32,mode="standard")print(f"seq_len={seq_len}: FlashAttention V2={t_flash*1000:.2f}ms, "f"标准Attention={t_std*1000:.2f}ms, "f"比值={t_flash/t_std:.2f}×")

实测结果(Atlas 800T A2,batch_size=1):

seq_len=512: FlashAttention V2=0.52ms, 标准Attention=0.45ms, 比值=1.16× seq_len=1024: FlashAttention V2=0.89ms, 标准Attention=1.02ms, 比值=0.87× seq_len=2048: FlashAttention V2=1.45ms, 标准Attention=2.31ms, 比值=0.63× seq_len=4096: FlashAttention V2=1.80ms, 标准Attention=4.20ms, 比值=0.43× seq_len=8192: FlashAttention V2=2.90ms, 标准Attention=12.80ms, 比值=0.23× 结论:seq_len≥1024时,FlashAttention V2的训练速度就开始超过标准Attention seq_len越长,优势越明显

总结:训练场景配置清单

FlashAttention训练配置,按这个清单选:

配置项选项建议
FlashAttention版本V1 / V2seq_len≥8192用V2,否则用V1
显存优化Gradient Checkpointing显存不够时叠加使用
混合精度FP16 + FP32 SoftmaxSoftmax累加用FP32
batch_size根据显存调用npu-smi监控,不超过85%

训练时的判断标准

  • 反向传播时间 ≈ 1.5-2×前向传播时间(V2正常)
  • 如果反向传播时间 > 3×前向传播时间,说明V2开销过大,考虑用V1
  • 显存占用应该比标准Attention低80%以上(V2)

代码和文档:

https://atomgit.com/cann/ops-transformer

http://www.jsqmd.com/news/907385/

相关文章:

  • SAP推出AI智能体中枢,统一管理企业多厂商智能体
  • Axure RP安装(已汉化)附下载地址
  • 用DeepXDE搞定薛定谔方程:一个Python物理信息神经网络(PINN)实战教程
  • PyEcharts常用图
  • Mermaid Live Editor:免费在线图表编辑器的终极解决方案,轻松创建专业图表
  • 别再为layui上传进度条发愁了!手把手教你用layer弹窗实现文件上传进度可视化(附完整PHP后端代码)
  • 宽频抗干扰更稳定:鼎讯信通 ZN‑061A 手持式信号综合分析仪应用
  • 为什么92%的团队用Sora 2做不出可用元宇宙资产?揭秘3层隐性技术门槛与2024Q2最新破解方案
  • 5分钟搞定!中国科学技术大学Beamer模板终极使用指南
  • CSDN日常运营方法
  • 大模型公司开始派人进客户现场,属于产品经理的转型时刻要来了?
  • 随心剪 99.2 分断层登顶!AI 智能剪辑赛道权威评测 TOP1
  • 简单学习 --> 模型的短期记忆
  • AutoCAD 2024 + Visual Studio 2022 ARX 二次开发从零到 Hello World 保姆级教程——001环境搭建
  • 从《星露谷物语》到你的项目:用Unity ScriptableObject设计一个可扩展的合成与交易系统
  • PLC数据对接MES,有哪几种方式?HTTP、MQTT、OPC UA怎么选
  • 探访TeraWulf 750MW AI数据中心:建设速度达到“中国水平“
  • 【C++】一文搞懂引用特性,附带顺序表完整代码实现
  • Cortex-M中断处理机制与调试技巧详解
  • 从0开始搭建自动化(二)-flutter-这个方案实在弄不来(选择了appium+python)
  • SPI通信模式0和模式3怎么选?实测W25Q128FV在STM32 HAL库下的兼容性问题与调试心得
  • 别再死记硬背公式了!用Python手写线性回归,从MSE、R²到梯度下降一次搞懂
  • 深入解析 SmartPrintAI:基于 MAF + DeepSeek + MCP 的智能物流打印平台
  • 免费服务器指南:GitHub Pages搭建静态网站全攻略
  • Bootstrap方法避坑指南:什么时候用?什么时候千万别用?(附R代码验证)
  • 从安装到第一个视觉项目:Halcon20.11环境搭建与‘Hello World’实战
  • Conan C++ 包管理工具深度解析
  • 26HVV护网行动 初 中 高 级人员招聘
  • 7nm工艺下,我为什么从ICC2换到了Innovus?聊聊真实项目里的那些坑
  • 测试左移 + 右移 + 自动化,三位一体构建质量护城河