CANN/pypto矩阵乘法API文档
pypto.matmul
【免费下载链接】pyptoPyPTO(发音: pai p-t-o):Parallel Tensor/Tile Operation编程范式。项目地址: https://gitcode.com/cann/pypto
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
实现input 、mat2矩阵的矩阵乘运算,计算公式为:out = input @ mat2
- input 、mat2为源操作数,input 为左矩阵;mat2为右矩阵
- out 为目的操作数,存放矩阵乘结果的矩阵
注意事项
- 左右矩阵数据类型必须一致:matmul 的左右矩阵数据类型必须相同(如 BF16+BF16、FP16+FP16),不支持混合输入(如 BF16+FP32),FP8数据类型除外
- 推荐使用低精度输入:BF16/FP16 输入直接 matmul 输出 FP32,比先 cast 到 FP32 再 matmul 性能更好,且精度相当
- 避免不必要的 cast:将 BF16 升级到 FP32 再进行 matmul 计算不会有精度提升,反而会产生额外的数据搬移开销
- 利用随路 transpose:matmul 支持
a_trans和b_trans参数,可以在矩阵乘时随路完成转置,避免额外调用 transpose 操作 - 必须先设置 TileShape:调用 matmul 接口前需要通过
set_cube_tile_shapes设置 M、N、K 轴上的切分大小
函数原型
matmul(input, mat2, out_dtype, *, a_trans = False, b_trans = False, c_matrix_nz = False, extend_params=None) -> Tensor参数说明
表1:API参数说明
| 参数名 | 输入/输出 | 说明 |
|---|---|---|
| input | 输入 | 表示输入左矩阵,不支持输入空Tensor。 输入数据类型支持情况详见表3。 支持的矩阵维度:2维、3维、4维,且左右矩阵维度需保持一致。 输入矩阵支持的Format为:TILEOP_ND,TILEOP_NZ(DT_FP32,DT_FP8E5M2,DT_HF8输入不支持TILEOP_NZ格式)。 内轴外轴:当输入矩阵input非转置时,对应数据排布为[M, K],此时外轴为M,内轴为K;当输入矩阵input转置时,对应数据排布为[K, M],此时外轴为K,内轴为M; 当Format为TILEOP_ND(ND格式)时,外轴范围为[1, 2^31 - 1],内轴范围为[1, 65535]。 当Format为TILEOP_NZ(NZ格式)时,其Shape维度需满足内轴32字节对齐,外轴16元素对齐。 在使用pypto.view接口的场景,应保证传入View的Shape维度也满足内轴32字节对齐,外轴16元素对齐。 |
| mat2 | 输入 | 表示输入右矩阵,不支持输入空Tensor。 输入数据类型支持情况详见表3。 支持的矩阵维度:2维、3维、4维,且左右矩阵维度需保持一致。 输入矩阵支持的Format为:TILEOP_ND,TILEOP_NZ(DT_FP32,DT_FP8E5M2,DT_HF8输入不支持TILEOP_NZ格式)。 内轴外轴:当输入矩阵mat2非转置时,对应数据排布为[K, N],此时外轴为K,内轴为N;当输入矩阵mat2转置时,对应数据排布为[N, K],此时外轴为N,内轴为K; 当Format为TILEOP_ND(ND格式)时,外轴范围为[1, 2^31 - 1],内轴范围为[1, 65535]。 当Format为TILEOP_NZ(NZ格式)时,其Shape维度需满足内轴32字节对齐,外轴16元素对齐。 在使用pypto.view接口的场景,应保证传入View的Shape维度也满足内轴32字节对齐,外轴16元素对齐。 |
| out_dtype | 输出 | 表示输出矩阵数据类型。输入输出数据类型支持情况详见表3。 |
| a_trans | 输入 | 参数a_trans表示输入左矩阵是否转置,默认为False。 |
| b_trans | 输入 | 参数b_trans表示输入右矩阵是否转置,默认为False。 |
| c_matrix_nz | 输入 | 参数c_matrix_nz表示输出矩阵的Format是否采用NZ格式,默认为False,当前仅支持设置False,即输出矩阵仅支持ND格式。 |
| extend_params | 输入 | 支持bias、fixpipe反量化及TF32舍入模式功能,详见表2。 bias、fixpipe反量化输入输出数据类型支持情况详见表3,表4。 - 数据类型为字典格式。 - 此参数与其内部参数均为可选参数。 |
表2:extend_params参数说明
| 参数名 | 说明 |
|---|---|
| scale | 表示pertensor量化场景(使用同一个缩放因子将高精度数映射到低精度数)输出矩阵反量化的参数。 输入为float类型,取1位符号位 + 8位指数位 + 10位尾数位参与运算。 输入输出数据类型支持情况详见表4。 不支持叠加多核切k功能。 |
| scale_tensor | 表示perchannel量化场景(对每一个输出通道独立计算一套量化参数)输出矩阵反量化的矩阵。 scale_tensor输入固定为uint64_t 的Tensor。计算时会转换uint64_t为float类型的低32位bit后,取1位符号位 + 8位指数位 + 10位尾数位参与运算。 输入输出数据类型支持情况详见表4。 scale_tensor的第一维度必须置1,且N维度需要与mat2矩阵的N维度相等。 scale_tensor只支持ND格式。 仅支持矩阵维度为2维场景。 不支持叠加多核切k功能。 |
| bias_tensor | 表示偏置矩阵。 输入为Tensor类型。 输入输出数据类型支持情况详见表3。 bias_tensor只支持ND格式。 bias_tensor的第一维度应置1,且N维度需要与mat2矩阵的N维度相等。 仅支持矩阵维度为2维场景。 不支持叠加多核切k功能。 |
| relu_type | 表示输出矩阵是否进行ReLu操作。 输入为ReLuType类型。 支持RELU和NO_RELU两种模式。 仅支持矩阵维度为2维场景。 |
| trans_mode | 表示是否使能TF32计算及TF32舍入模式。 输入为TransMode类型,支持以下三种模式: • CAST_NONE:不使能float数据类型转换为TF32数据类型。 • CAST_RINT:使能float数据类型转换为TF32数据类型,舍入规则:舍入到最近整数,中间值时舍入到偶数。 • CAST_ROUND:使能float数据类型转换为TF32数据类型,舍入规则:舍入到最近整数,中间值时远离零舍入。 仅支持输入左右矩阵和输出矩阵数据类型均为DT_FP32时设置。 仅支持矩阵维度为2维场景。 |
表3: Matmul支持的数据类型
| input | mat2 | out_dtype | bias_tensor | 产品支持 |
|---|---|---|---|---|
| DT_FP16 | DT_FP16 | DT_FP16,DT_FP32 | DT_FP16,DT_FP32 | 全系列 |
| DT_BF16 | DT_BF16 | DT_BF16,DT_FP32 | DT_FP32,DT_BF16(DT_BF16仅950PR/DT支持) | 全系列 |
| DT_FP32 | DT_FP32 | DT_FP32 | DT_FP32 | 全系列 |
| DT_INT8 | DT_INT8 | DT_INT32 | DT_INT32 | 全系列 |
| DT_FP8E5M2 | DT_FP8E5M2,DT_FP8E4M3 | DT_FP16,DT_BF16,DT_FP32 | DT_FP16,DT_BF16,DT_FP32 | 仅950PR/DT |
| DT_FP8E4M3 | DT_FP8E5M2,DT_FP8E4M3 | DT_FP16,DT_BF16,DT_FP32 | DT_FP16,DT_BF16,DT_FP32 | 仅950PR/DT |
| DT_HF8 | DT_HF8 | DT_FP16,DT_BF16,DT_FP32 | DT_FP16,DT_BF16,DT_FP32 | 仅950PR/DT |
表4: 反量化支持的数据类型
| input | mat2 | out_dtype | 产品支持 |
|---|---|---|---|
| DT_INT8 | DT_INT8 | DT_FP16 | 全系列 |
返回值说明
返回值为out 矩阵(Tensor)。
约束说明
- Atlas A2 训练系列产品/Atlas A2 推理系列产品:不支持DT_HF8,DT_FP8E5M2,DT_FP8E4M3,不支持extend_params中的trans_mode参数。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:不支持DT_HF8,DT_FP8E5M2,DT_FP8E4M3,不支持extend_params中的trans_mode参数。
- 调用matmul接口前需要通过pypto.set_cube_tile_shapes设置M、N、K轴上的切分大小
- 当矩阵维度为3维或者4维时,需要调用pypto.set_vec_tile_shapes接口设置vector的TileShape切分,如未设置,接口内部会设置2维的vec_tile_shape,其值为128,128。
- 调用matmul接口的输入为调用pypto.reshape后的NZ格式时,需要调用pypto.set_matrix_size接口设置pypto.reshape前的输入到matmul的原始Shape的m,k,n值。
- 调用matmul接口的输入矩阵维度为3维/4维并且数据格式为NZ格式时,需要调用pypto.set_matrix_size接口设置输入到matmul的原始Shape的m,k,n值。
调用示例
# 基本矩阵乘 a1 = pypto.tensor([16, 32], pypto.DT_BF16, "tensor_a") b1 = pypto.tensor([32, 64], pypto.DT_BF16, "tensor_b") out1 = pypto.matmul(a1, b1, pypto.DT_BF16) # 批量矩阵乘 a2 = pypto.tensor((2, 16, 32), pypto.DT_FP16, "tensor_a") b2 = pypto.tensor((2, 32, 16), pypto.DT_FP16, "tensor_b") out2 = pypto.matmul(a2, b2, pypto.DT_FP16) # 批次广播 a3 = pypto.tensor((1, 32, 64), pypto.DT_FP32, "tensor_a") b3 = pypto.tensor((3, 64, 16), pypto.DT_FP32, "tensor_b") out3 = pypto.matmul(a3, b3, pypto.DT_FP32) # 叠加Bias a = pypto.tensor((16, 32), pypto.DT_FP16, "tensor_a") b = pypto.tensor((32, 64), pypto.DT_FP16, "tensor_b") bias = pypto.tensor((1, 64), pypto.DT_FP16, "tensor_bias") extend_params = {'bias_tensor': bias} pypto.matmul(a, b, pypto.DT_FP32, extend_params=extend_params) # 反量化 a = pypto.tensor((16, 32), pypto.DT_INT8, "tensor_a") b = pypto.tensor((32, 64), pypto.DT_INT8, "tensor_b") extend_params = {'scale': 0.2} pypto.matmul(a, b, pypto.DT_BF16, extend_params=extend_params) # 反量化叠加RELU extend_params = {'scale': 0.2, 'relu_type': pypto.ReLuType.RELU} pypto.matmul(a, b, pypto.DT_BF16, extend_params=extend_params) # 反量化叠加RELU scale_tensor = pypto.tensor((1, 64), pypto.DT_UINT64, "tensor_scale") extend_params = {'scale_tensor': scale_tensor, 'relu_type': pypto.ReLuType.RELU} pypto.matmul(a, b, pypto.DT_BF16, extend_params=extend_params) # TF32计算模式(仅950PR/DT) a = pypto.tensor((16, 32), pypto.DT_FP32, "tensor_a") b = pypto.tensor((32, 64), pypto.DT_FP32, "tensor_b") extend_params = {'trans_mode': pypto.TransMode.CAST_ROUND} pypto.matmul(a, b, pypto.DT_FP32, extend_params=extend_params)【免费下载链接】pyptoPyPTO(发音: pai p-t-o):Parallel Tensor/Tile Operation编程范式。项目地址: https://gitcode.com/cann/pypto
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
