当前位置: 首页 > news >正文

CANN/ops-tensor量化矩阵乘法调度器

Block Scheduler Quant Batch Matmul

【免费下载链接】ops-tensorops-tensor 是 CANN (Compute Architecture for Neural Networks)算子库中提供张量类计算的基础算子库,采用模块化设计,支持灵活的算子开发和管理。项目地址: https://gitcode.com/cann/ops-tensor

代码位置

功能说明

量化 Batch Matmul 调度器,支持多 Batch 维度切分、尾块切分、负载均衡、Z 型扫描。适用于 Quant Batch Matmul MX Kernel 场景,支持 MxFP4/MxFP8 量化数据类型。

继承自:Block Scheduler 公共框架

模板参数

template < class ProblemShape_, // 问题规模类型 uint64_t FullLoadMode_, // 全载模式(0=非全载,1=A全载) class LayoutA_, // A 矩阵布局类型 class LayoutB_, // B 矩阵布局类型 class AType_> // A 矩阵数据类型(用于判断 C0_SIZE) class BlockSchedulerQuantBatchMatmulV3;

全载模式

支持两种全载模式:

  • FullLoadMode_ = 0:非全载模式(默认)
  • FullLoadMode_ = A_FULL_LOAD_MODE(1):A 矩阵全载模式

C0_SIZE 计算

根据数据类型自动计算 C0 对齐大小:

  • FP4(fp4x2_e2m1_t, fp4x2_e1m2_t):C0_SIZE = 64
  • FP8(fp8_e5m2_t, fp8_e4m3fn_t):C0_SIZE = 32

转置判断

根据布局类型判断转置:

  • transAIsTrans<LayoutA_>::value
  • transBIsTrans<LayoutB_>::value

尾块切分

支持尾块 tile 切分:

  • mTailTile:M 轴尾块切分数量
  • nTailTile:N 轴尾块切分数量
  • totalTailTile:总尾块切分数量(mTailTile × nTailTile)

Z 型扫描

使用 Z 型扫描策略:

  • WINDOW_LEN = 4:扫描窗口大小
  • 正向扫描:偶数行(rowIdx % 2 == 0)
  • 反向扫描:奇数行(rowIdx & 1)

负载均衡

支持负载均衡配置:

  • mBaseNormCnt:M 轴正常 tile 数量
  • mBaseTailMain:M 轴尾块主尺寸
  • nBaseNormCnt:N 轴正常 tile 数量
  • nBaseTailMain:N 轴尾块主尺寸

特殊静态常量

常量说明
C0_SIZEC0 对齐大小(FP4: 64,FP8: 32)
WINDOW_LENZ 型扫描窗口长度(4)
transAA 矩阵是否转置
transBB 矩阵是否转置

特殊类型别名

类型说明
BlockShapeBlock 形状:Shape<int64_t, int64_t, int64_t, int64_t>
BlockCoordBlock 坐标:Coord<int64_t, int64_t, int64_t, int64_t>
ProblemShape问题规模类型(模板参数)
ATypeA 矩阵数据类型(模板参数)

特殊数据结构

Params

struct Params { int64_t baseM; // L0 M 维度 base 大小 int64_t baseN; // L0 N 维度 base 大小 int64_t mTailTile; // M 轴尾块切分数量 int64_t nTailTile; // N 轴尾块切分数量 int64_t mBaseTailSplitCnt; // M 轴尾块 L1 切分数量 int64_t nBaseTailSplitCnt; // N 轴尾块 L1 切分数量 int64_t mTailMain; // M 轴尾块主尺寸 int64_t nTailMain; // N 轴尾块主尺寸 };

特殊成员变量

变量说明
m_, n_, k_问题规模
baseM_, baseN_L0 base 形状
mCnt_, nCnt_, totalCnt_tile 数量
mBaseNormCnt_, nBaseNormCnt_正常 tile 数量
mBaseTailMain_, nBaseTailMain_尾块主尺寸
mBaseTailLast_, nBaseTailLast_尾块最后尺寸
mCoreNum_, mTailCoreNum_M 轴核心数量、尾核心数量
blockIdx_, blockNum_当前 Block 索引、总 Block 数量
startBlockIdx_, endBlockIdx_起始/结束 Block 索引
roundIdx_, round_当前轮次、总轮次
mTailTile_, nTailTile_, totalTailTile_尾块切分数量
mSplitAddrOffset_, nSplitAddrOffset_尾块切分偏移
mainRow_主行数

特殊成员方法

构造函数

__aicore__ inline BlockSchedulerQuantBatchMatmulV3(const ProblemShape& shape, const Params& params)

功能:初始化 BlockSchedulerQuantBatchMatmulV3,计算 tile 切分、尾块参数、轮次等。 参数说明: | 参数 | 类型 | 说明 | |------|------|------| | shape | ProblemShape | 问题规模(m, n, k)| | params | Params | 调度参数 |

执行流程:

  1. 设置问题规模:m_,n_,k_,baseM_,baseN_
  2. 计算 tile 数量:mCnt_,nCnt_,totalCnt_
  3. 计算扫描窗口:mCoreNum_,mainRow_,mTailCoreNum_
  4. 计算轮次:endBlockIdx_,round_
  5. 计算尾块参数:根据transAtransB计算尾块尺寸

UpdateTailTile

__aicore__ inline void UpdateTailTile(uint32_t mTailTile, uint32_t nTailTile)

功能:更新尾块切分数量,重新计算结束 Block 索引和轮次。 参数说明: | 参数 | 类型 | 说明 | |------|------|------| | mTailTile | uint32_t | M 轴尾块切分数量 | | nTailTile | uint32_t | N 轴尾块切分数量 |

GetTotalCnt

__aicore__ inline int64_t GetTotalCnt()

功能:返回总 tile 数量(totalCnt_)。

GetEndBlockIdx

__aicore__ inline int64_t GetEndBlockIdx()

功能:返回结束 Block 索引(endBlockIdx_)。

CalSingleCoreShapeByCoord

__aicore__ inline void CalSingleCoreShapeByCoord(int64_t& singleCoreM, int64_t& singleCoreN, const BlockCoord& blockCoord)

功能:根据 Block 坐标计算单核形状(处理尾块)。 参数说明: | 参数 | 类型 | 说明 | |------|------|------| | singleCoreM | int64_t& | 单核 M 维度(原地修改) | | singleCoreN | int64_t& | 单核 N 维度(原地修改) | | blockCoord | BlockCoord | Block 坐标 |

GetBlockShape

template <QuantBatchMatmul::QuantMode aQuantMode, QuantBatchMatmul::QuantMode bQuantMode, bool weightNz = false> __aicore__ inline BlockShape GetBlockShape(BlockCoord blockCoord)

功能:返回当前 Block 的形状,支持量化模式和 NZ 格式。 参数说明: | 参数 | 类型 | 说明 | |------|------|------| | aQuantMode | QuantMode | A 矩阵量化模式(PERGROUP/PERBLOCK) | | bQuantMode | QuantMode | B 矩阵量化模式(PERGROUP/PERBLOCK) | | weightNz | bool | B 矩阵是否为 NZ 格式 | | blockCoord | BlockCoord | Block 坐标 |

返回值:BlockShape {singleCoreM, singleCoreN, mSplitAddrOffset_, nSplitAddrOffset_}

特殊逻辑:

  • 尾块切分判断totalTailTile_ > 1 && roundIdx_ == round_
  • FP4 对齐:FP4 + transA 时 M 对齐到 2,FP4 + !transB 时 N 对齐到 2
  • PERBLOCK 模式:对齐到 PER_BLOCK_SIZE 或 2 的幂次方
  • NZ 格式对齐:根据 transB 对齐到 C0_SIZE 或 BLOCK_CUBE

GetLoadBalanceInfo

__aicore__ inline AscendC::Std::tuple<uint32_t, uint32_t, uint32_t, uint32_t> GetLoadBalanceInfo()

功能:返回负载均衡信息{mBaseNormCnt_, mBaseTailMain_, nBaseNormCnt_, nBaseTailMain_}

UpdateNextBatchBlockRoundParams

__aicore__ inline void UpdateNextBatchBlockRoundParams()

功能:更新下一 Batch 的 Block 轮次参数。 执行流程:

  1. 更新startBlockIdx_endBlockIdx_
  2. 重置roundIdx_ = 0
  3. 重新计算round_

GetTileIdx

__aicore__ inline bool GetTileIdx(BlockCoord& blockCoord)

功能:获取当前轮次的 tile 索引,更新 Block 坐标。 参数说明: | 参数 | 类型 | 说明 | |------|------|------| | blockCoord | BlockCoord& | Block 坐标(原地修改) |

返回值:

  • true:当前轮次有效,返回 tile 坐标
  • false:当前轮次结束

执行流程:

  1. 判断轮次结束:roundIdx_ >= round_
  2. 计算 tile 索引:根据全载模式计算
  3. Z 型扫描:计算blockCoordM,blockCoordN
  4. 反向扫描:奇数行反向(blockCoordN = nCnt_ - 1 - blockCoordN
  5. 更新roundIdx_++

GetTileCoord

__aicore__ inline void GetTileCoord(const BlockCoord& blockCoord, int64_t& mPos, int64_t& nPos)

功能:根据 Block 坐标计算 GM 地址偏移。 参数说明: | 参数 | 类型 | 说明 | |------|------|------| | blockCoord | BlockCoord | Block 坐标 | | mPos | int64_t& | M 轴 GM 偏移(原地修改) | | nPos | int64_t& | N 轴 GM 偏移(原地修改) |

调用示例

组件组装

using ProblemShape = AscendC::Te::Shape<int64_t, int64_t, int64_t>; using LayoutA = AscendC::Te::NZLayoutPtn; using LayoutB = AscendC::Te::NZLayoutPtn; using AType = fp4x2_e2m1_t; constexpr uint64_t FULL_LOAD_MODE = 0; using BlockScheduler = Blaze::Gemm::Block::BlockSchedulerQuantBatchMatmulV3< ProblemShape, FULL_LOAD_MODE, LayoutA, LayoutB, AType>;

参数准备

BlockScheduler::Params params = { baseM, // L0 M 维度 base(如 128) baseN, // L0 N 维度 base(如 128) mTailTile, // M 轴尾块切分数量(如 1) nTailTile, // N 轴尾块切分数量(如 1) mBaseTailSplitCnt, // M 轴尾块 L1 切分数量(如 1) nBaseTailSplitCnt, // N 轴尾块 L1 切分数量(如 1) mTailMain, // M 轴尾块主尺寸(如 1) nTailMain // N 轴尾块主尺寸(如 1) };

组件初始化

ProblemShape shape{m, n, k}; BlockScheduler scheduler(shape, params);

更新尾块切分

scheduler.UpdateTailTile(mTailTile, nTailTile);

获取 tile 数量

int64_t totalCnt = scheduler.GetTotalCnt();

Tile 循环处理

BlockCoord blockCoord{0, 0, 0, 0}; while (scheduler.GetTileIdx(blockCoord)) { // 获取 Block 形状 constexpr auto aQuantMode = QuantBatchMatmul::QuantMode::PERGROUP_MODE; constexpr auto bQuantMode = QuantBatchMatmul::QuantMode::PERGROUP_MODE; constexpr bool weightNz = true; auto blockShape = scheduler.GetBlockShape<aQuantMode, bQuantMode, weightNz>(blockCoord); int64_t singleCoreM = Get<0>(blockShape); int64_t singleCoreN = Get<1>(blockShape); int64_t mSplitOffset = Get<2>(blockShape); int64_t nSplitOffset = Get<3>(blockShape); // 获取 GM 地址偏移 int64_t mPos, nPos; scheduler.GetTileCoord(blockCoord, mPos, nPos); // 执行 BlockMmadMX 计算 // ... }

更新下一 Batch

scheduler.UpdateNextBatchBlockRoundParams();

获取负载均衡信息

auto loadBalanceInfo = scheduler.GetLoadBalanceInfo(); uint32_t mBaseNormCnt = std::get<0>(loadBalanceInfo); uint32_t mBaseTailMain = std::get<1>(loadBalanceInfo); uint32_t nBaseNormCnt = std::get<2>(loadBalanceInfo); uint32_t nBaseTailMain = std::get<3>(loadBalanceInfo);

数据流

Tile 切分流程

问题规模 (m, n, k) ↓ L0 tile 切分 (baseM, baseN) ↓ tile 数量计算 (mCnt × nCnt) ↓ 尾块参数计算 (mBaseNormCnt, mBaseTailMain, nBaseNormCnt, nBaseTailMain) ↓ 轮次计算 (round, startBlockIdx, endBlockIdx) ↓ Z 型扫描 (mCoreNum, mainRow, mTailCoreNum) ↓ Block 形状/坐标 (singleCoreM, singleCoreN, mPos, nPos)

量化模式对齐流程

GetBlockShape<QuantMode, QuantMode, weightNz> ↓ 尾块切分判断 (totalTailTile > 1 && roundIdx == round) ↓ FP4 对齐:transA → M 对齐到 2,!transB → N 对齐到 2 ↓ PERBLOCK 模式:对齐到 PER_BLOCK_SIZE 或 2 的幂次方 ↓ NZ 格式对齐:!transB → C0_SIZE,transB → BLOCK_CUBE ↓ 返回 BlockShape {singleCoreM, singleCoreN, mSplitAddrOffset, nSplitAddrOffset}

Z 型扫描流程

tileIdx 计算 ↓ rowIdx = tileIdx / nCnt / mCoreNum ↓ rowIdx < mainRow:blockCoordM = rowIdx × mCoreNum + tileIdx % mCoreNum ↓ rowIdx == mainRow:尾窗口计算 ↓ rowIdx & 1:反向扫描(blockCoordN = nCnt - 1 - blockCoordN)

性能优化建议

baseM/baseN 配置

  • 建议值:根据量化数据类型选择(如 128)
  • C0 对齐:确保 baseM/baseN 对齐到 C0_SIZE(FP4: 64,FP8: 32)

尾块切分配置

  • mTailTile/nTailTile:建议尾块切分数量不超过 4
  • mBaseTailSplitCnt/nBaseTailSplitCnt:建议 L1 尾块切分数量为 1(不切分)

全载模式选择

  • 非全载模式(FullLoadMode = 0):适用于一般场景
  • A 全载模式(FullLoadMode = A_FULL_LOAD_MODE):适用于大 K、小 M 场景

量化模式选择

  • PERGROUP_MODE:per-group 量化,适用于小规模量化
  • PERBLOCK_MODE:per-block 量化,对齐要求更高

NZ 格式优化

  • weightNz = true:B 矩阵 NZ 格式,提升 L1/L0 搬运效率
  • 对齐要求:根据 transB 选择 C0_SIZE 或 BLOCK_CUBE

适用场景

  • Quant Batch Matmul MX Kernel:量化 Batch Matmul
  • MxFP4/MxFP8 量化:支持 FP4 E2M1/E1M2 和 FP8 E5M2/E4M3FN
  • 多 Batch 维度:支持 4 维 Batch(batchA/A2/A3/A4)
  • 负载均衡:动态调整 tile 分配

【免费下载链接】ops-tensorops-tensor 是 CANN (Compute Architecture for Neural Networks)算子库中提供张量类计算的基础算子库,采用模块化设计,支持灵活的算子开发和管理。项目地址: https://gitcode.com/cann/ops-tensor

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

http://www.jsqmd.com/news/895890/

相关文章:

  • 构建多智能体系统核心:Agent2Agent交互层架构与实战
  • 用Matplotlib heatmap分析你的数据:从销售报表到用户行为矩阵的3个实战案例
  • Android TEE实战指南:从架构解析到安全应用开发
  • 3种方案深度解析:Windows Defender性能优化与安全组件管理
  • 3分钟快速上手:Switch手柄PC适配终极指南
  • 终极iOS应用自由指南:TrollInstallerX一键安装教程
  • 变压器漏感测量:从传统认知到仿真验证的实践洞察
  • LumiPi训练技术揭秘:LoRA在扩散变换器上的HDR训练方法
  • 本地部署语音AI助手:基于Whisper与LangChain的私有化智能体搭建指南
  • BetterJoy完整指南:5分钟让Switch手柄在PC上完美运行
  • 终极指南:如何快速解锁QQ音乐加密音频,免费转换为MP3/FLAC格式
  • Windows Defender彻底移除指南:专业系统安全组件管理工具详解
  • 思源宋体:如何用7款免费字体提升中文排版专业度
  • 如何用BetterNCM安装器5分钟解锁网易云音乐隐藏功能
  • CPU本地语音AI实战:Pocket Studio三模型对比与Docker部署指南
  • Nandi-Mini-600M模型架构深度解析:从Transformer到高效推理
  • 低代码平台表单设计器 unione-form-editor 组件 —— 二维码组件
  • 终极指南:如何用Keyboard Chatter Blocker免费解决机械键盘连击问题
  • CognitiveFusion2-4x7B-BF16推理优化终极指南:BF16精度与内存管理技巧详解
  • 5个简单步骤掌握HLS流媒体下载:HLS Downloader终极使用指南
  • 终极指南:如何用免费PlantUML编辑器快速绘制专业UML图表
  • 认知科学赋能LLM:23种提示工程技巧提升AI输出质量
  • 从感觉编程到规范驱动开发:AI时代软件工程的质量保障实践
  • 从用量看板观察Taotoken按Token计费带来的成本透明度
  • 猫抓浏览器扩展终极指南:三步轻松下载网页视频资源
  • 3步搞定Unity游戏去马赛克:UniversalUnityDemosaics终极指南
  • LumiPic与LumiVid对比分析:单图像与视频HDR生成技术的终极指南 [特殊字符]
  • 装修公司哪家好?陕西峰淘装饰,全包套餐 700–1200 元 /㎡ - myqiye
  • 跨平台流媒体下载终极指南:N_m3u8DL-RE深度解析
  • 3步终极方案:用Mac Mouse Fix让普通鼠标在macOS上超越触控板!