CANN/ops-transformer Floyd注意力梯度算子
FusedFloydAttentionGrad
【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
算子功能:训练场景下,计算Floyd注意力的反向输出,FloydAttn相较于传统FA主要是计算qk/pv注意力时会额外将seq作为batch轴从而转换为batchMatmul。
计算公式:
已知注意力的正向计算公式为:
$$ P=Softmax(Mask(scale*(QK_1^T + QK_2^T), atten_mask)) \ Y=(PV_1+PV_2) $$
则注意力的反向计算公式为:
$$ S=Softmax(S) $$
$$ dV_1=P^TdY $$
$$ dV_2=P^TdY $$
$$ dQ=\frac{((dS)*K_1)}{\sqrt{d}}+\frac{((dS)*K_2)}{\sqrt{d}} $$
$$ dK_1=\frac{((dS)^T*Q)}{\sqrt{d}} $$
$$ dK_2=\frac{((dS)^T*Q)}{\sqrt{d}} $$
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | 公式中的输入Q。 | FLOAT16、BFLOAT16 | ND |
| key1 | 输入 | 公式中的输入K1。 | FLOAT16、BFLOAT16 | ND |
| value1 | 输入 | 公式中的输入V1。 | FLOAT16、BFLOAT16 | ND |
| key2 | 输入 | 公式中的输入K2。 | FLOAT16、BFLOAT16 | ND |
| value2 | 输入 | 公式中的输入V2。 | FLOAT16、BFLOAT16 | ND |
| dy | 输入 | 公式中的输入dY。 | FLOAT16、BFLOAT16 | ND |
| attenMaskOptional | 可选输入 | 公式中的atten_mask,表示注意力掩码,取值为1代表该位不参与计算(不生效),为0代表该位参与计算。 | BOOL、UINT8 | ND |
| scaleValue | 可选属性 |
| DOUBLE | - |
| dqOut | 输出 | 公式中的dQ,表示query的梯度。 | FLOAT16、BFLOAT16 | ND |
| dk1Out | 输出 | 公式中的dK1,表示key1的梯度。 | FLOAT16、BFLOAT16 | ND |
| dv1Out | 输出 | 公式中的dV1,表示value1的梯度。 | FLOAT16、BFLOAT16 | ND |
| dk2Out | 输出 | 公式中的dK2,表示key2的梯度。 | FLOAT16、BFLOAT16 | ND |
| dv2Out | 输出 | 公式中的dV2,表示value2的梯度。 | FLOAT16、BFLOAT16 | ND |
约束说明
该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配
关于数据shape的约束,其中:
- B:取值范围为1~2K。
- H:取值范围为1~256。
- N:取值范围为16~1M且N%16==0。
- M:取值范围为128~1M且M%128==0。
- K:取值范围为128~1M且K%128==0。
- D:取值范围为32/64/128。
query与key1的第0/2/4轴需相同。
key1与value1 shape需相同。
key2与value2 shape需相同。
query与dy/attentionIn shape需相同。
softmaxMax与softmaxSum shape需相同。
D只支持32/64/128。
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_fused_floyd_attention_grad | 通过接口方式调用aclnnFusedFloydAttentionGrad算子。 |
【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
