CANN/ops-nn AddRmsNormDynamicQuant算子
AddRmsNormDynamicQuant
【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
| Kirin X90 处理器系列产品 | √ |
| Kirin 9030 处理器系列产品 | √ |
功能说明
算子功能:RmsNorm算子是大模型常用的归一化操作,相比LayerNorm算子,其去掉了减去均值的部分。DynamicQuant算子则是为输入张量进行对称动态量化的算子。AddRmsNormDynamicQuantV2算子将RmsNorm前的Add算子和RmsNorm归一化输出给到的1个或2个DynamicQuant算子融合起来,减少搬入搬出操作。AddRmsNormDynamicQuant算子相较于AddRmsNormDynamicQuantV2在RmsNorm计算过程中增加了偏置项betaOptional参数,即计算对应公式中的beta,以及新增输出配置项output_mask参数,用于配置是否输出对应位置的量化结果。
计算公式:
$$ x=x_{1}+x_{2} $$
$$ y = \operatorname{RmsNorm}(x)=\frac{x}{\operatorname{Rms}(\mathbf{x})}\cdot gamma+beta, \quad \text { where } \operatorname{Rms}(\mathbf{x})=\sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2+epsilon} $$
$$ input1 =\begin{cases} y\cdot smoothScale1Optional & \ \ smoothScale1Optional\ != null \ y & \ \ smoothScale1Optional\ = null \end{cases} $$
$$ input2 =\begin{cases} y\cdot smoothScale2Optional & \ \ smoothScale2Optional\ != null \ y & \ \ smoothScale2Optional\ = null \end{cases} $$
$$ scale1Out=\begin{cases} row_max(abs(input1))/127 & outputMask[0]=True\ ||\ outputMask\ = null \ 无效输出 & outputMask[0]=False \end{cases} $$
$$ y1Out=\begin{cases} round(input1/scale1Out) & outputMask[0]=True\ ||\ outputMask\ = null \ 无效输出 & outputMask[0]=False \end{cases} $$
$$ scale2Out=\begin{cases} row_max(abs(input2))/127 & outputMask[1]=True\ ||\ (outputMask\ = null\ &\ smoothScale1Optional\ != null\ &\ smoothScale2Optional\ != null) \ 无效输出 & outputMask[1]=False\ ||\ (outputMask\ = null\ &\ (smoothScale1Optional\ = null\ ||\ smoothScale2Optional\ = null)) \end{cases} $$
$$ y2Out=\begin{cases} round(input2/scale2Out) & outputMask[1]=True\ ||\ (outputMask\ = null\ &\ smoothScale1Optional\ != null\ &\ smoothScale2Optional\ != null)\ 无效输出 & outputMask[1]=False\ ||\ (outputMask\ = null\ &\ (smoothScale1Optional\ = null\ ||\ smoothScale2Optional\ = null)) \end{cases} $$
公式中的row_max代表每行求最大值。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x1 | 输入 | 表示标准化过程中的源数据张量,对应公式中的`x1`。 | FLOAT16、BFLOAT16 | ND |
| x2 | 输入 | 表示标准化过程中的源数据张量,对应公式中的`x2`。 | FLOAT16、BFLOAT16 | ND |
| gamma | 输入 | 表示标准化过程中的权重张量,对应公式中的`gamma`。shape需要与x1最后一维一致。 | FLOAT16、BFLOAT16 | ND |
| smooth_scale1 | 可选输入 | 表示量化过程中得到y1使用的smoothScale张量,对应公式中的`smoothScale1Optional`。 | FLOAT16、BFLOAT16 | ND |
| smooth_scale2 | 可选输入 | 表示量化过程中得到y2使用的smoothScale张量,对应公式中的`smoothScale2Optional`。 | FLOAT16、BFLOAT16 | ND |
| beta | 可选输入 | 表示标准化过程中的偏置项,对应公式中的`beta`。 | FLOAT16、BFLOAT16 | ND |
| epsilon | 可选属性 |
| FLOAT32 | - |
| output_mask | 可选属性 |
| LISTBOOL | - |
| y1 | 输出 | 表示量化输出Tensor,对应公式中的`y1Out`。 | INT8、HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN、INT4 | ND |
| y2 | 输出 | 表示量化输出Tensor,对应公式中的`y2Out`。 | INT8、HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN、INT4 | ND |
| x | 输出 | 表示x1和x2的和,对应公式中的`x`。 | FLOAT16、BFLOAT16 | ND |
| scale1 | 输出 | 第一路量化的输出,对应公式中的`scale1Out`。 | FLOAT32 | ND |
| scale2 | 输出 | 第二路量化的输出,对应公式中的`scale2Out`。 | FLOAT32 | ND |
Ascend 950PR/Ascend 950DT :
- 暂不支持可选属性
output_mask的配置。 - 输出参数
y1、y2的数据类型不支持INT4。
- 暂不支持可选属性
Atlas A3 训练系列产品/Atlas A3 推理系列产品 、 Atlas A2 训练系列产品/Atlas A2 推理系列产品 :
输出参数
y1、y2的数据类型仅支持INT4、INT8。Kirin X90/Kirin 9030处理器系列产品: x1、x2、gamma、smooth_scale1、smooth_scale2、beta和x的数据类型不支持BFLOAT16。 y1和y2不支持HIFLOAT8、FLOAT8_E5M2、FLOAT8_E4M3FN、INT4。
约束说明
当output_mask不为空时,参数smooth_scale1有值时,则output_mask[0]必须为True。参数smooth_scale2有值时,则output_mask[1]必须为True。
当output_mask为空时,参数smooth_scale2有值时,参数smooth_scale1不能为空。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_add_rms_norm_dynamic_quant | 通过aclnnAddRmsNormDynamicQuant接口方式调用AddRmsNormDynamicQuant算子。 |
| aclnn接口 | test_aclnn_add_rms_norm_dynamic_quant_v2 | 通过aclnnAddRmsNormDynamicQuantV2接口方式调用AddRmsNormDynamicQuant算子。 |
| 图模式 | - | 通过算子IR构图方式调用AddRmsNormDynamicQuant算子。 |
【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
