CANN8.5-ops-transformer更新了什么昇腾NPU算子
CANN 8.5 在 2024 Q4 发布,ops-transformer 仓库跟进了三个重要更新:FlashAttention V2 的反向传播融合、MC2 通算融合的多卡拓扑适配、以及新增的 GroupedAttention 算子。如果你已经用 CANN 8.0 在跑大模型推理,这篇帮你判断要不要升级。
FlashAttention V2:反向传播终于融合了
CANN 8.0 的 FlashAttention 只融合了前向传播。训练场景下,反向传播还是要拆成三个独立 kernel(dQ、dK、dV),中间结果落显存。CANN 8.5 把反向传播也融合进去了——一次 kernel 完成 dQ/dK/dV 的计算。
训练场景的收益:
| 配置 | 训练吞吐 (tokens/s/p) | 显存占用 |
|---|---|---|
| CANN 8.0 FlashAttention V1 | 1,820 | 56 GB |
| CANN 8.5 FlashAttention V2 | 2,410 | 44 GB |
训练吞吐提升 32%,显存省了 21%。显存省下来意味着可以开更大的 batch 或更长的序列。
前向推理不受影响——如果你只做推理,这个更新对你没用,不用特意升级。
MC2 通算融合:多卡拓扑适配
CANN 8.0 的 MC2 只支持同一台服务器内的卡间通信。8 卡的 Atlas 800I A2 没问题,但如果你要用两台服务器做 16 卡的 MoE 训练,跨机 All-to-All 通信走的是 RoCE,MC2 没法把它和计算重叠。
CANN 8.5 加了 RoCE 通算融合支持。MC2 可以同时管理 HCCL 的卡间通信和 RoCE 的跨机通信,让两者都跟计算流水线重叠。
实测数据,Mixtral 8x7B 的 16 卡训练:
| 配置 | 通信占比 | 吞吐 (tokens/s/p) |
|---|---|---|
| CANN 8.0(跨机不通算融合) | 42% | 680 |
| CANN 8.5(跨机通算融合) | 23% | 1,050 |
通信占比从 42% 降到 23%,跨机场景的 MC2 终于能用了。
新增:GroupedAttention 算子
Grouped-Query Attention(GQA)在 Llama2、Mistral 等模型里广泛使用。CANN 8.0 需要把 GQA 展开 MHA 来跑,CANN 8.5 新增了原生 GQA 支持:
importtorch_npu# GQA: num_q_heads=32, num_kv_heads=8q=torch.randn(1,32,4096,128,device="npu",dtype=torch.float16)k=torch.randn(1,8,4096,128,device="npu",dtype=torch.float16)v=torch.randn(1,8,4096,128,device="npu",dtype=torch.float16)# CANN 8.5 直接支持 KV heads < Q headsout=torch_npu.npu.flash_attention(q,k,v)CANN 8.0 要手动把 K/V repeat 到 32 个 head,显存和计算都浪费。原生 GQA 省掉了 repeat 操作,显存占用降低 75%,延迟降 15-20%。
升级建议
| 场景 | 是否建议升级到 8.5 |
|---|---|
| 只做推理(单机) | 不急,8.0 够用 |
| 推理 + GQA 模型 | 建议升级,原生 GQA 收益大 |
| 训练(单机) | 建议升级,FlashAttention V2 反向融合省显存 |
| 训练(多机 MoE) | 必须升级,跨机 MC2 是刚需 |
升级方式:
# 更新 CANN toolkit./Ascend-cann-toolkit_8.5.run--install# 重新编译 ops-transformercdops-transformer&&gitpull&&bashbuild.sh# 更新 torch_npupipinstalltorch_npu==2.3.0# CANN 8.5 对应版本兼容性注意
CANN 8.5 的 FlashAttention V2 API 跟 8.0 的 V1 有个不兼容变更:npu.flash_attention的scale参数从位置参数改成了关键字参数。如果你之前的代码是flash_attention(q, k, v, 1.0/math.sqrt(dim)),需要改成flash_attention(q, k, v, scale=1.0/math.sqrt(dim))。不改的话会报参数类型错误。
如果你的 MoE 训练要上多机,CANN 8.5 的跨机 MC2 是硬需求,不升级就是浪费卡。单机推理用户可以观望,等下一个大版本再看。仓库在这里:
https://atomgit.com/cann/ops-transformer
