CANN稀疏注意力元数据算子文档
SparseAttnSharedkvMetadata
【免费下载链接】cann-recipes-infer本项目针对LLM与多模态模型推理业务中的典型模型、加速算法,提供基于CANN平台的优化样例项目地址: https://gitcode.com/cann/cann-recipes-infer
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 推理系列产品 | √ |
功能说明
- API功能:
SparseAttnSharedkvMetadata算子旨在生成一个任务列表,包含每个AIcore的Attention计算任务的起止点的Batch、Head、以及 Q 和 K 的分块的索引,供后续SparseAttnSharedkv算子使用。
函数原型
custom.npu_sparse_attn_sharedkv_metadata(num_heads_q, num_heads_kv, head_dim, *, cu_seqlens_q=None, cu_seqlens_ori_kv=None, cu_seqlens_cmp_kv=None, seqused_q=None, seqused_kv=None, batch_size=0, max_seqlen_q=0, max_seqlen_kv=0, ori_topk=0, cmp_topk=0, cmp_ratio=-1, ori_mask_mode=4, cmp_mask_mode=3, ori_win_left=127, ori_win_right=0, layout_q='BSND', layout_kv='PA_ND', has_ori_kv=True, has_cmp_kv=True, device='npu:0') -> Tensor- Transformer SparseAttnSharedkvMetadata 算子实现参考: SparseAttnSharedkvMetadata
参数说明
num_heads_q(
int):必选参数,表示公式中的$Q$的多头数,目前仅支持64。num_heads_kv(
int):必选参数,表示公式中的$\tilde{K}$和$\tilde{V}$的多头数,目前仅支持1。head_dim(
int):必选参数,表示注意力头的维度。*:必选参数,代表其之前的变量是位置相关的,必须按照顺序输入;之后的变量是可选参数,位置无关,需要使用键值对赋值,不赋值会使用默认值。
cu_seqlens_q(
Tensor):可选参数,当layout_query为TND时,表示不同Batch中q的有效token数,维度为B+1,大小为参数中每个元素的值表示目前batch与之前所有batch的token数总和,即前缀和,数据类型支持int32。cu_seqlens_ori_kv(
Tensor):可选参数,当layout_kv为TND时,表示不同Batch中ori_kv的有效token数,维度为B+1,大小为参数中每个元素的值表示目前batch与之前所有batch的token数总和,即前缀和,数据类型支持int32。目前layout_kv仅支持PA_ND,故设置此参数无效。cu_seqlens_cmp_kv(
Tensor):可选参数,当layout_kv为TND时,表示不同Batch中cmp_kv的有效token数,维度为B+1,大小为参数中每个元素的值表示目前batch与之前所有batch的token数总和,即前缀和,数据类型支持int32。目前layout_kv仅支持PA_ND,故设置此参数无效。seqused_q(
Tensor):可选参数,表示不同Batch中q实际参与运算的token数,维度为B,数据格式支持ND,数据类型支持int32,不输入则所有token均参与运算。目前暂不支持指定该参数。seqused_kv(
Tensor):可选参数,表示不同Batch中ori_kv实际参与运算的token数,维度为B,数据格式支持ND,数据类型支持int32,不输入则所有token均参与运算。batch_size(
int):可选参数,表示输入样本批量大小,默认值为None。max_seqlen_q(
int):可选参数,表示所有batch中q的最大有效token数。max_seqlen_kv(
int):可选参数,表示所有batch中ori_kv的最大有效token数。ori_topk(
int):可选参数,表示通过QLI算法从ori_kv中筛选出的关键稀疏token的个数。目前暂不支持指定该参数,默认值为None。cmp_topk(
int):可选参数,表示通过QLI算法从cmp_kv中筛选出的关键稀疏token的个数,目前仅支持512,默认值为None。cmp_ratio(
int):可选参数,表示对ori_kv的压缩率,数据范围支持4/128,默认值为None。ori_mask_mode(
int):可选参数,表示q和ori_kv计算的mask模式,目前仅支持输入默认值4,代表band模式的mask。cmp_mask_mode(
int):可选参数,表示q和cmp_kv计算的mask模式,目前仅支持输入默认值3,代表rightDownCausal模式的mask,对应以右顶点为划分的下三角场景。ori_win_left(
int):可选参数,表示q和ori_kv计算中q对过去token计算的数量,目前仅支持默认值127。ori_win_right(
int):可选参数,表示q和ori_kv计算中q对未来token计算的数量,目前仅支持默认值0。layout_q(
str):可选参数,表示输入q的数据排布格式,默认值为BSND,目前支持传入BSND和TND。layout_kv(
str):可选参数,表示输入ori_kv和cmp_kv的数据排布格式,目前仅支持传入默认值PA_ND(PageAttention)。has_ori_kv(
bool):可选参数,表示是否传入ori_kv,默认值为true。has_cmp_kv(
bool):可选参数,表示是否传入cmp_kv,默认值为true。device(
str):可选参数,用于获取设备信息,默认值为None。
返回值说明
- metadata(
Tensor):每个cube核上FlashAttention计算任务的Batch、Head、以及 Q 和 K 的分块的索引,以及每个vector核上FlashDecode的规约任务索引。
约束说明
- 该接口支持推理场景下使用。
- 该接口支持aclgraph模式。
- Tensor不能全传None。
Atlas A3 推理系列产品 调用示例
- 支持单算子模式调用和aclgraph模式调用,作为SparseAttnSharedkv算子的前序算子,调用示例见SparseAttnSharedkv调用示例。
【免费下载链接】cann-recipes-infer本项目针对LLM与多模态模型推理业务中的典型模型、加速算法,提供基于CANN平台的优化样例项目地址: https://gitcode.com/cann/cann-recipes-infer
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
