当前位置: 首页 > news >正文

CANN/ops-nn转置量化矩阵乘

aclnnTransposeQuantBatchMatMul

【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn

📄 查看源码

产品支持情况

产品是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品×
Atlas A2 训练系列产品/Atlas A2 推理系列产品×
Atlas 200I/500 A2 推理产品×
Atlas 推理系列产品×
Atlas 训练系列产品×

功能说明

  • 接口功能:完成张量x1与张量x2量化的矩阵乘计算,支持K-C、MX量化模式。仅支持三维的Tensor传入。Tensor支持转置,转置序列根据传入的数列进行变更。permX1代表张量x1的转置序列,permX2代表张量x2的转置序列,序列值为0的是batch维度,其余两个维度做矩阵乘法。

  • 示例: 假设x1的shape是(M, B, K),x2的shape是(B, K, N),x1Scale和x2Scale不为None,batchSplitFactor等于1时,计算输出out的shape是(M, B, N)。

函数原型

每个算子分为两段式接口,必须先调用“aclnnTransposeQuantBatchMatMulGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnTransposeQuantBatchMatMul”接口执行计算。

aclnnStatus aclnnTransposeQuantBatchMatMulGetWorkspaceSize( const aclTensor* x1, const aclTensor* x2, const aclTensor* bias, const aclTensor* x1Scale, const aclTensor* x2Scale, const int32_t dtype, const int64_t groupSize, const aclIntArray* permX1, const aclIntArray* permX2, const aclIntArray* permY, const int32_t batchSplitFactor, aclTensor* out, uint64_t* workspaceSize, aclOpExecutor** executor)
aclnnStatus aclnnTransposeQuantBatchMatMul( void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, const aclrtStream stream)

aclnnTransposeQuantBatchMatMulGetWorkSpaceSize

  • 参数说明:

    参数名输入/输出描述使用说明数据类型数据格式维度(shape)非连续Tensor
    x1(aclTensor*)输入表示矩阵乘的第一个矩阵。
    • 数据类型需要与x2满足数据类型推导规则(参见互推导关系和约束说明)。
    • 数据类型当前仅支持FLOAT8_E5M2、FLOAT8_E4M3FN。
    FLOAT8_E5M2、FLOAT8_E4M3FNND3
    x2(aclTensor*)输入表示矩阵乘的第二个矩阵。
    • 数据类型需要与x1满足数据类型推导规则(参见互推导关系和约束说明)。
    • x2的k维度需要与x1的k维度大小相等。
    • 数据类型当前仅支持FLOAT8_E5M2、FLOAT8_E4M3FN。
    FLOAT8_E5M2、FLOAT8_E4M3FNND3
    bias(aclTensor*)输入表示矩阵乘的偏置矩阵。预留参数,当前暂不支持。----
    x1Scale(aclTensor*)输入表示左矩阵的量化系数。
    • K-C量化场景,shape支持一维且需要等于[m]。
    • MX量化场景,shape支持四维。
    FLOAT32、FLOAT8_E8M0ND1
    x2Scale(aclTensor*)输入表示右矩阵的量化系数。
    • K-C量化场景,shape支持一维且需要等于[n]。
    • MX量化场景,shape支持四维。
    FLOAT32、FLOAT8_E8M0ND1
    dtype(int32_t)输入用于指定输出矩阵的数据类型,支持的值为:1、27。
    • 取值为1, 表示输出矩阵类型为FLOAT16。
    • 取值为27, 表示输出矩阵类型为BFLOAT16。
    INT32---
    groupSize(int64_t)输入用于输入m、n、k方向上的量化分组大小。由3个方向的groupSizeM,groupSizeN,groupSizeK 三个值拼接组成,每个值占16位,共占用int64_t类型groupSize的低48位(groupSize中的高16位的数值无效)INT64---
    permX1(aclIntArray*)输入表示矩阵乘的第一个矩阵的转置序列,host侧的aclIntArray。支持[1, 0, 2]。INT64-1-
    permX2(aclIntArray*)输入表示矩阵乘的第二个矩阵的转置序列,host侧的aclIntArray。
    • K-C量化场景,支持[0, 1, 2]。
    • MX量化场景,支持[0, 1, 2]或[0, 2, 1]。
    INT64-1-
    permY(aclIntArray*)输入表示矩阵乘输出矩阵的转置序列,host侧的aclIntArray。支持[1, 0, 2]。INT64-1-
    batchSplitFactor(int32_t)输入用于指定矩阵乘输出矩阵中B维的切分大小,Host侧的整型,当前仅支持取值为1。当前取值仅支持为1。INT32---
    out(aclTensor*)输出表示矩阵乘的输出矩阵,公式中的out。不支持空Tensor。BFLOAT16、FLOAT16ND3-
  • 返回值:

    aclnnStatus: 返回状态码,具体参见aclnn返回码。

    第一段接口完成入参校验,出现以下场景时报错:

    返回值错误码描述
    ACLNN_ERR_PARAM_NULLPTR161001传入的x1、x2、out、x1Scale、x2Scale、permX1、permX2、permY是空指针。
    ACLNN_ERR_PARAM_INVALID161002x1、x2、x1Scale、x2Scale或out的数据类型不在支持的范围内。
    x1、x2、x1Scale、x2Scale或out的shape不满足校验条件。
    x1、x2、permX1、permX2、permY的维度大小不等于3。
    batchSplitFactor不在支持的范围内
    permX1、permX2、permY的取值不在支持的范围内。

aclnnTransposeQuantBatchMatMul

  • 参数说明:

    参数名输入/输出描述
    workspace输入在Device侧申请的workspace内存地址。
    workspaceSize输入在Device侧申请的workspace大小,由第一段接口aclnnTransposeQuantBatchMatMulGetWorkspaceSize获取。
    executor输入op执行器,包含了算子计算流程。
    stream输入指定执行任务的stream。
  • 返回值:

    aclnnStatus: 返回状态码,具体参见aclnn返回码。

约束说明

  • 确定性说明: aclnnTransposeQuantBatchMatMul默认确定性实现。

  • Ascend 950PR/Ascend 950DT :

    • K-C量化场景,K仅支持512,N仅支持128。x1Scale和x2Scale仅支持1维,并且x1Scale要求shape为(M,), x2Scale要求shape为(N,),group_size仅支持配置为0,其他取值不生效。
    • MX量化场景,K仅支持64的倍数。 x1Scale和x2Scale仅支持4维,并且x1Scale要求shape为(M, B, K/64, 2), 当permX2为[0, 1, 2]时,x2Scale要求shape为(B, K/64, N, 2);当permX2为[0, 2, 1]时,x2Scale要求shape为(B, N, K/64, 2)。group_size的groupSizeM和groupSizeN仅支持0或1,groupSizeK仅支持32。
    • groupSize相关约束:
      • 仅在MX量化场景中生效。
      • 传入的groupSize内部会按如下公式分解得到groupSizeM、groupSizeN、groupSizeK,当其中有1个或多个为0,会根据x1/x2/x1Scale/x2Scale输入shape重新设置groupSizeM、groupSizeN、groupSizeK用于计算。原理:假设groupSizeM=0,表示M方向量化分组值由接口推断,推断公式为groupSizeM = M / scaleM(需保证M能被scaleM整除),其中M与x1 shape中的M一致,scaleM与x1Scale shape中的M一致。

    $$ groupSize = groupSizeK | groupSizeN << 16 | groupSizeM << 32 $$

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。

#include <iostream> #include <memory> #include <vector> #include <limits> #include <cmath> #include "acl/acl.h" #include "aclnnop/aclnn_transpose_quant_batch_mat_mul.h" #define CHECK_RET(cond, return_expr) \ do { \ if (!(cond)) { \ return_expr; \ } \ } while (0) #define CHECK_FREE_RET(cond, return_expr) \ do { \ if (!(cond)) { \ Finalize(deviceId, stream); \ return_expr; \ } \ } while (0) #define LOG_PRINT(message, ...) \ do { \ printf(message, ##__VA_ARGS__); \ } while (0) int64_t GetShapeSize(const std::vector<int64_t>& shape) { int64_t shapeSize = 1; for (auto i : shape) { shapeSize *= i; } return shapeSize; } int Init(int32_t deviceId, aclrtStream* stream) { // 固定写法,资源初始化 auto ret = aclInit(nullptr); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret); ret = aclrtSetDevice(deviceId); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret); ret = aclrtCreateStream(stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret); return 0; } // BF16 到 float 的转换函数 float bf16_to_float(uint16_t bf16) { uint16_t sign = (bf16 >> 15) & 0x1; uint16_t exp = (bf16 >> 7) & 0xFF; // 8 位指数 uint16_t mant = bf16 & 0x7F; // 特殊值处理 if (exp == 0) { if (mant == 0) { return sign ? -0.0f : 0.0f; } else { // 非规格化 BF16 -> float return (sign ? -1.0f : 1.0f) * (float)mant * (1.0f / (1 << 7) / std::ldexp(1.0, 126)); } } else if (exp == 255) { // 无穷大或 NaN if (mant == 0) { return sign ? -std::numeric_limits<float>::infinity() : std::numeric_limits<float>::infinity(); } else { return std::numeric_limits<float>::quiet_NaN(); } } else { // 规格化数 float f_exp = (float)(exp - 127); // 偏移 127 float f_mant = (float)mant / (1 << 7); // 7 位小数 float f = (sign ? -1.0f : 1.0f) * (1.0f + f_mant) * (1 << (int)f_exp); return f; } } template <typename T> int CreateAclTensor( const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr, aclDataType dataType, aclTensor** tensor) { auto size = GetShapeSize(shape) * sizeof(T); // 调用aclrtMalloc申请Device侧内存 auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret); // 调用aclrtMemcpy将Host侧数据拷贝到Device侧内存上 ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret); // 计算连续tensor的strides std::vector<int64_t> strides(shape.size(), 1); for (int64_t i = shape.size() - 2; i >= 0; i--) { strides[i] = shape[i + 1] * strides[i + 1]; } // 调用aclCreateTensor接口创建aclTensor *tensor = aclCreateTensor( shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr); return 0; } void Finalize(int32_t deviceId, aclrtStream stream) { aclrtDestroyStream(stream); aclrtResetDevice(deviceId); aclFinalize(); } int AclnnTransposeQuantBatchMatmulTest(int32_t deviceId, aclrtStream& stream) { auto ret = Init(deviceId, &stream); // check根据自己的需要处理 CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret); // 2. 构造输入与输出,需要根据API的接口自定义构造 int32_t M = 32; int32_t K = 512; int32_t N = 128; int32_t Batch = 16; std::vector<int64_t> x1Shape = {M, Batch, K}; std::vector<int64_t> x2Shape = {Batch, K, N}; std::vector<int64_t> x1ScaleShape = {M}; std::vector<int64_t> x2ScaleShape = {N}; std::vector<int64_t> outShape = {M, Batch, N}; std::vector<int64_t> permX1Series = {1, 0, 2}; std::vector<int64_t> permX2Series = {0, 1, 2}; std::vector<int64_t> permYSeries = {1, 0, 2}; void* x1DeviceAddr = nullptr; void* x2DeviceAddr = nullptr; void* x1ScaleDeviceAddr = nullptr; void* x2ScaleDeviceAddr = nullptr; void* outDeviceAddr = nullptr; aclTensor* x1 = nullptr; aclTensor* x2 = nullptr; aclTensor* x1Scale = nullptr; aclTensor* x2Scale = nullptr; aclTensor* out = nullptr; std::vector<int8_t> x1HostData(GetShapeSize(x1Shape), 0x38); std::vector<int8_t> x2HostData(GetShapeSize(x2Shape), 0x38); std::vector<float> x1ScaleHostData(GetShapeSize(x1ScaleShape), 1); std::vector<float> x2ScaleHostData(GetShapeSize(x2ScaleShape), 1); std::vector<uint16_t> outHostData(GetShapeSize(outShape), 0); // bf16 // 创建x1 aclTensor ret = CreateAclTensor(x1HostData, x1Shape, &x1DeviceAddr, aclDataType::ACL_FLOAT8_E4M3FN, &x1); std::unique_ptr<aclTensor, aclnnStatus (*)(const aclTensor*)> x1TensorPtr(x1, aclDestroyTensor); std::unique_ptr<void, aclError (*)(void*)> x1deviceAddrPtr(x1DeviceAddr, aclrtFree); CHECK_RET(ret == ACL_SUCCESS, return ret); // 创建x2 aclTensor ret = CreateAclTensor(x2HostData, x2Shape, &x2DeviceAddr, aclDataType::ACL_FLOAT8_E4M3FN, &x2); std::unique_ptr<aclTensor, aclnnStatus (*)(const aclTensor*)> x2TensorPtr(x2, aclDestroyTensor); std::unique_ptr<void, aclError (*)(void*)> x2deviceAddrPtr(x2DeviceAddr, aclrtFree); CHECK_RET(ret == ACL_SUCCESS, return ret); // 创建x1Scale aclTensor ret = CreateAclTensor(x1ScaleHostData, x1ScaleShape, &x1ScaleDeviceAddr, aclDataType::ACL_FLOAT, &x1Scale); std::unique_ptr<aclTensor, aclnnStatus (*)(const aclTensor*)> x1ScaleTensorPtr(x1Scale, aclDestroyTensor); std::unique_ptr<void, aclError (*)(void*)> x1ScaledeviceAddrPtr(x1ScaleDeviceAddr, aclrtFree); CHECK_RET(ret == ACL_SUCCESS, return ret); // 创建x2Scale aclTensor ret = CreateAclTensor(x2ScaleHostData, x2ScaleShape, &x2ScaleDeviceAddr, aclDataType::ACL_FLOAT, &x2Scale); std::unique_ptr<aclTensor, aclnnStatus (*)(const aclTensor*)> x2ScaleTensorPtr(x2Scale, aclDestroyTensor); std::unique_ptr<void, aclError (*)(void*)> x2ScaledeviceAddrPtr(x2ScaleDeviceAddr, aclrtFree); CHECK_RET(ret == ACL_SUCCESS, return ret); // 创建out aclTensor ret = CreateAclTensor(outHostData, outShape, &outDeviceAddr, aclDataType::ACL_BF16, &out); std::unique_ptr<aclTensor, aclnnStatus (*)(const aclTensor*)> outTensorPtr(out, aclDestroyTensor); std::unique_ptr<void, aclError (*)(void*)> outdeviceAddrPtr(outDeviceAddr, aclrtFree); CHECK_RET(ret == ACL_SUCCESS, return ret); aclIntArray* permX1 = aclCreateIntArray(permX1Series.data(), permX1Series.size()); aclIntArray* permX2 = aclCreateIntArray(permX2Series.data(), permX2Series.size()); aclIntArray* permY = aclCreateIntArray(permYSeries.data(), permYSeries.size()); uint64_t workspaceSize = 0; aclOpExecutor* executor = nullptr; std::unique_ptr<void, aclError (*)(void*)> executorAddrPtr(nullptr, aclrtFree); int32_t batchSplitFactor = 1; int32_t groupSize = 0; int32_t dtype = 27; // bf16 // aclnnTransposeQuantBatchMatMul接口调用示例 // 3. 调用CANN算子库API,需要修改为具体的API名称 // 调用aclnnTransposeQuantBatchMatMul第一段接口 ret = aclnnTransposeQuantBatchMatMulGetWorkspaceSize( x1, x2, (const aclTensor*)nullptr, x1Scale, x2Scale, dtype, groupSize, permX1, permX2, permY, batchSplitFactor, out, &workspaceSize, &executor); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnTransposeQuantBatchMatMulGetWorkspaceSize failed. ERROR: %d\n", ret); return ret); // 根据第一段接口计算出的workspaceSize申请device内存 void* workspaceAddr = nullptr; if (workspaceSize > 0) { ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret); executorAddrPtr.reset(workspaceAddr); } // 调用aclnnTransposeQuantBatchMatMul第二段接口 ret = aclnnTransposeQuantBatchMatMul(workspaceAddr, workspaceSize, executor, stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnTransposeQuantBatchMatMul failed. ERROR: %d\n", ret); return ret); // 4. (固定写法)同步等待任务执行结束 ret = aclrtSynchronizeStream(stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret); // 5. 获取输出的值,将Device侧内存上的结果拷贝至Host侧,需要根据具体API的接口定义修改 auto size = GetShapeSize(outShape); std::vector<uint16_t> resultData(size, 0); // bf16 ret = aclrtMemcpy( resultData.data(), resultData.size() * sizeof(resultData[0]), outDeviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret); float resultDataBF16 = 0; for (int64_t i = 0; i < size; i++) { resultDataBF16 = bf16_to_float(resultData[i]); LOG_PRINT("result[%ld] is: %f\n", i, resultDataBF16); } return ACL_SUCCESS; } int main() { // 1. (固定写法)device/stream初始化,参考acl API手册 // 根据自己的实际device填写deviceId int32_t deviceId = 0; aclrtStream stream; auto ret = AclnnTransposeQuantBatchMatmulTest(deviceId, stream); CHECK_FREE_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnTransposeQuantBatchMatMulTest failed. ERROR: %d\n", ret); return ret); Finalize(deviceId, stream); return 0; }

【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

http://www.jsqmd.com/news/797435/

相关文章:

  • 2026年云南省陆良师宗地区汽车贴膜行业横向测评 - 速递信息
  • LayerDivider:如何用AI算法实现智能图像分层,提升设计效率10倍?
  • svg-captcha 字体定制教程:如何加载和使用自定义字体
  • 别再死记公式了!用Python+LTspice快速搞定LC滤波器设计(附仿真文件)
  • 别再死记硬背了!用Python手动画图,彻底搞懂TensorFlow的tf.nn.depth_to_space
  • Windows文件同步终极指南:SyncTrayzor让多设备文件管理变得简单高效
  • 智能水表、血糖仪、工业HMI:STM32L152ZET6的超低功耗MCU应用版图
  • 【Google官方未公开】Gemini字幕引擎底层架构拆解:Transformer-Lite模型+端侧ASR协同机制(附实测WER对比表)
  • 【信息科学与工程学】【云计算】边缘-云协同 第五篇 边缘-云协同资源编排技术——03容器与调度器技术
  • 探索2026年济南留学中介十强,哪些机构成功率高 - 速递信息
  • Java面试题全网最全整理(附答案),已按模块分类,从基础到实战一篇搞定
  • 大连奢侈品包包变现服务测评:五家平台分级解析,收的顶以专业引领行业 - 奢侈品回收测评
  • py每日spider案例之某插件请求接口加密参数逆向(aes 难度一般)
  • HYMiniMall项目实战:如何基于现有架构快速扩展新功能模块的完整指南
  • CANN Ascend C SetStride API
  • CANN/asc-devkit SetStartPosition API文档
  • 紧急更新!Midjourney v6.2.1已悄然调整Pokeberry印相底层LUT加载机制:3小时内必须重校准的2个关键变量
  • 重庆川岳机电设备:高新区可靠的设备吊装哪家好 - LYL仔仔
  • Gemini如何重构Google搜索体验:3个被90%开发者忽略的AI增强接口与调用陷阱
  • 2026年 不锈钢工程厂家推荐排行榜:房屋、商场、写字楼、会展中心等多场景不锈钢工程优质之选! - 速递信息
  • CANN/Ascend C WholeReduceSum API文档
  • Qt 软件外包开发流程
  • 3分钟上手FanControl:让Windows电脑风扇更智能更安静
  • Springboot+Vue3|毕业设计美食分享平台(源码)
  • 2026交调系统排行榜,广州聚杰芯科凭多系列产品覆盖全场景监测 - 品牌速递
  • 2026年云南省汽车后市场观察:V-KOOL威固陆良金锋旗舰店打造本地化贴膜服务标杆 - 速递信息
  • 温州市方氏建材:苍南专业的建材批发工厂 - LYL仔仔
  • 从零到一:基于STM32F030的SPI驱动74HC595实战解析
  • CANN/asc-devkit SPM缓冲区写入API
  • 黄金变现选对平台少走弯路,厦门 5 家机构测评:收的顶全国连锁更放心 - 奢侈品回收测评