当前位置: 首页 > news >正文

CANN/cann-bench GQA算子API描述

GQA 算子 API 描述

【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力,涵盖算子生成、算子优化等领域,支撑模型选型、训练效果评估,统一量化评估标准,识别Agent能力短板,构建CANN领域评测平台,推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench

1. 算子简介

分组查询注意力 (Grouped Query Attention) 算子,多个 query head 共享一组 key/value head,对已分头的 Q/K/V 执行注意力计算,在保持模型质量的同时显著减少 KV cache 内存占用和推理计算量。

主要应用场景

  • 大语言模型推理中的高效注意力计算(如 LLaMA-2 70B、Mistral)
  • 长序列推理场景中降低 KV cache 内存开销
  • 需要在模型质量和推理效率之间平衡的 Transformer 架构

算子特征

  • 难度等级:L4(FusedComposite)
  • 多输入(query, key, value)单输出,执行分组缩放点积注意力
  • 输入为已分头的张量,不包含 QKV 投影和输出投影步骤
  • N_q 必须能被 N_kv 整除,每个 KV head 被 N_q/N_kv 个 query head 共享

2. 算子定义

数学公式

对于第 $i$ 个 query head,使用第 $\lfloor i \times N_{kv} / N_q \rfloor$ 个 KV head:

$$ \text{head}i = \text{softmax}\left(Q_i \times K{g(i)}^T \times \text{scaleValue}\right) \times V_{g(i)} $$

其中:

  • $N_q$ 为 query 头数,$N_{kv}$ 为 KV 头数,$N_q$ 必须能被 $N_{kv}$ 整除
  • $g(i) = \lfloor i \times N_{kv} / N_q \rfloor$ 为第 $i$ 个 query head 对应的 KV head 索引
  • $D$ 为每个头的维度
  • $\text{scaleValue}$ 为缩放因子(<=0 时自动使用 $1/\sqrt{D}$)
  • 每个 KV head 被 $N_q / N_{kv}$ 个 query head 共享

具体子步骤:

  1. KV head 扩展:将每个 KV head 重复 $N_q / N_{kv}$ 次以匹配 query head 数
  2. 缩放点积:$\text{scores} = Q_i \times K_{g(i)}^T \times \text{scaleValue}$
  3. Softmax 归一化:$\text{attn_weights} = \text{softmax}(\text{scores}, \text{dim}=-1)$
  4. 加权求和:$y_i = \text{attn_weights} \times V_{g(i)}$

3. 接口规范

算子原型

cann_bench.gqa(Tensor query, Tensor key, Tensor value, float scaleValue=-1.0) -> Tensor y

输入参数说明

参数类型默认值描述
queryTensor必选查询张量(已分头),shape 为 [B, S, N_q, D]
keyTensor必选键张量(已分头),shape 为 [B, S_kv, N_kv, D]
valueTensor必选值张量(已分头),shape 为 [B, S_kv, N_kv, D]
scaleValuefloat-1.0缩放因子,<=0 时自动使用 1/sqrt(D)

输出

参数Shapedtype描述
y[B, S, N_q, D]与输入 query 相同分组查询注意力输出张量

数据类型

输入 dtype输出 dtype
float16float16
float32float32
bfloat16bfloat16

规则与约束

  • 所有输入 Tensor(query, key, value)的 dtype 必须一致
  • query的 shape 为 [B, S, N_q, D],keyvalue的 shape 为 [B, S_kv, N_kv, D]
  • N_q 必须能被 N_kv 整除,分组比 G = N_q / N_kv
  • 当 N_kv == N_q 时退化为标准多头注意力 (MHA)
  • 当 N_kv == 1 时退化为多查询注意力 (MQA)
  • scaleValue通常设置为 $1/\sqrt{D}$,当 <= 0 时自动使用该值

4. 精度要求

采用生态算子精度标准进行验证。

误差指标

  1. 平均相对误差(MERE):采样点中相对误差平均值

    $$ \text{MERE} = \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)+\text{1e-7}}) $$

  2. 最大相对误差(MARE):采样点中相对误差最大值

    $$ \text{MARE} = \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)+\text{1e-7}}) $$

通过标准

数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2
通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2

当平均相对误差 MERE < Threshold,最大相对误差 MARE < 10 * Threshold 时判定为通过。

5. 标准 Golden 代码

import torch """ GQA算子Torch Golden参考实现 分组查询注意力 (Grouped Query Attention),多个 query head 共享一组 KV head 公式: 扩展 KV heads 匹配 Q heads,y = softmax(Q @ K^T * scaleValue) @ V """ def gqa( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, scaleValue: float = -1.0, ) -> torch.Tensor: """ 分组查询注意力 (Grouped Query Attention) Args: query: 查询张量 [B, S, N_q, D](已分头) key: 键张量 [B, S_kv, N_kv, D](已分头) value: 值张量 [B, S_kv, N_kv, D](已分头) scaleValue: 缩放因子,<=0 时自动使用 1/sqrt(D) Returns: 输出张量 [B, S, N_q, D] """ B, S, N_q, D = query.shape S_kv = key.shape[1] N_kv = key.shape[2] if scaleValue <= 0: scaleValue = 1.0 / (D ** 0.5) # 扩展 KV heads 以匹配 Q heads G = N_q // N_kv key = key.unsqueeze(3).expand(B, S_kv, N_kv, G, D).reshape(B, S_kv, N_q, D) value = value.unsqueeze(3).expand(B, S_kv, N_kv, G, D).reshape(B, S_kv, N_q, D) # 转置为 [B, N_q, S, D] q = query.transpose(1, 2) k = key.transpose(1, 2) v = value.transpose(1, 2) # 缩放点积注意力 scores = torch.matmul(q, k.transpose(-2, -1)) * scaleValue attn_weights = torch.nn.functional.softmax(scores, dim=-1) attn_output = torch.matmul(attn_weights, v) # 转回 [B, S, N_q, D] return attn_output.transpose(1, 2)

6. 额外信息

算子调用示例

import torch import cann_bench B, S, S_kv, D = 2, 128, 128, 128 N_q, N_kv = 32, 8 query = torch.randn(B, S, N_q, D, dtype=torch.float16, device="npu") key = torch.randn(B, S_kv, N_kv, D, dtype=torch.float16, device="npu") value = torch.randn(B, S_kv, N_kv, D, dtype=torch.float16, device="npu") y = cann_bench.gqa(query, key, value, scaleValue=-1.0)

【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力,涵盖算子生成、算子优化等领域,支撑模型选型、训练效果评估,统一量化评估标准,识别Agent能力短板,构建CANN领域评测平台,推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

http://www.jsqmd.com/news/783745/

相关文章:

  • 微信AI机器人插件生态全解析:从选型部署到开发实践
  • CANN/sip ColwiseMul按列逐点乘示例
  • 网盘下载提速神器:九大平台直链解析工具完整指南
  • Cursor API本地代理:内网集成AI编程与自动化工作流实战
  • 认知科学启发的AGI测试框架:从人类智能维度到可量化评估
  • HoRain云--PHP命名空间终极指南
  • pypto.distributed 模块介绍
  • Python后台服务/守护进程如何正确处理SIGINT信号?一个真实的生产环境案例
  • CANN/pyasc load_data数据加载API文档
  • 人形机器人供应链观察:良质关节如何在三年内成为头部厂商的核心合作伙伴?(附数字化案例拆解) - 黑湖科技老黑
  • CANN具身智能-PI0训练样例
  • HIXL LLM-DataDist接口
  • C++ ONNX Runtime 实战:为什么我的 session->Run 在跨函数调用时就崩溃了?
  • CANN/AMCT OFMR大模型量化
  • OpenClaw爬虫框架实战:从Awesome清单到自动化数据采集系统构建
  • 国内主流氯化镁生产厂家综合实力排行及选型指南 - 奔跑123
  • ngx_close_accepted_connection
  • 别再画丑图了!用Mermaid的gitGraph在Markdown里画专业Git分支图(附VSCode插件配置)
  • 基于OpenClaw构建多AI智能体协作平台:从数字生命蒸馏到理想国决策
  • 告别粘连字符!用Halcon的partition_dynamic算子精准分割OCR区域(附完整代码)
  • AI音乐生成技术解析:从符号与音频生成到混合模型实战
  • 向量引擎、deepseek v4、GPT Image 2、api key:Agent 时代最值钱的不是模型,是会调度的人
  • 外资阀门品牌2026市场介绍:米勒(Miller) - 米勒阀门
  • 基于微环谐振器的光子AI推理加速器:原理、设计与挑战
  • CANN算子测试竞赛中山大学软工小队提交
  • CANN/pypto lt函数API文档
  • 如何免费获取网盘高速下载:LinkSwift 九大平台直链解析终极指南
  • AI水下目标检测:从传统图像处理到深度学习部署实战
  • 工业盐技术选型指南:优质厂家的核心筛选维度 - 奔跑123
  • 别再只会用ref_table了!ABAP ALV里给自定义字段加F4搜索帮助的完整流程(附代码)