当前位置: 首页 > news >正文

CANN/catlass Block MMAD开发详解

Block MMAD 代码开发详解

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

1. Block MMAD 概述

Block MMAD(Block Matrix Multiply-Add)是CATLASS模板库中负责块级矩阵乘法的核心组件,位于计算架构的中间层。它上接Kernel层,下接Tile层,负责将全局内存(GM)中的数据高效地加载到本地内存(L1/L0),并调度Tile级别的矩阵乘法运算。

Block MMAD采用高度模块化和模板化的设计,支持多种调度策略、Tile形状和数据类型,能够灵活适应不同的硬件架构和计算需求,本文将以BlockMmadPingpong为例详细讲解。

2. 模板组装机制

Block MMAD实现基于以下基础模板结构:

template < class DispatchPolicy, class L1TileShape, class L0TileShape, class AType, class BType, class CType, class BiasType = void, class TileCopy = Gemm::Tile::TileCopy<typename DispatchPolicy::ArchTag, AType, BType, CType, BiasType>, class TileMmad = Gemm::Tile::TileMmad<typename DispatchPolicy::ArchTag, AType, BType, BiasType> > struct BlockMmad { // 类型定义和核心实现 };

2.1 核心模板参数

参数名描述
DispatchPolicy调度策略,控制计算任务的分配和执行流程
L1TileShapeL1缓存级别Tile的形状,定义M、N、K三个维度的大小
L0TileShapeL0缓存级别Tile的形状,定义M、N、K三个维度的大小
AType/BType/CType输入矩阵A、B和输出矩阵C的数据类型和布局信息
BiasType偏置数据类型(可选)
TileCopyTile级别的数据拷贝组件,负责不同内存层级间的数据传输
TileMmadTile级别的矩阵乘法组件,负责实际的计算操作

2.2 类型导出

Block MMAD通过类型导出系统建立统一的类型接口,方便上层组件使用:

public: // Type Aliases using DispatchPolicy = MmadAtlasA2Pingpong<ENABLE_UNIT_FLAG_>; using ArchTag = typename DispatchPolicy::ArchTag; using L1TileShape = L1TileShape_; using L0TileShape = L0TileShape_; using ElementA = typename AType_::Element; using LayoutA = typename AType_::Layout; using ElementB = typename BType_::Element; using LayoutB = typename BType_::Layout; using ElementC = typename CType_::Element; using LayoutC = typename CType_::Layout; using TileMmad = TileMmad_; using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator; using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; using LayoutCInL0 = layout::zN;

3. 内存管理与缓存设计

Block MMAD负责管理不同层级的内存,包括全局内存(GM)、L1缓存和L0缓存:

3.1 静态常量定义

static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; static constexpr uint32_t STAGES = DispatchPolicy::STAGES; static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES;

3.2 内存检查

// Check LayoutC static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!"); // Check L1TileShape static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); // Check L0TileShape static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator); static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!"); static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape exceeding the L0C space!");

3.3 多阶段缓存设计

Block MMAD采用多阶段流水线设计,使用pingpong技术隐藏内存访问延迟:

protected: // Multi-stage tensors list AscendC::LocalTensor<ElementA> l1ATensorList[STAGES]; AscendC::LocalTensor<ElementB> l1BTensorList[STAGES]; AscendC::LocalTensor<ElementA> l0ATensorList[STAGES]; AscendC::LocalTensor<ElementB> l0BTensorList[STAGES]; AscendC::LocalTensor<ElementAccumulator> l0CTensor; // Multi-stage event id list int32_t l1AEventList[STAGES]; int32_t l1BEventList[STAGES]; int32_t l0AEventList[STAGES]; int32_t l0BEventList[STAGES]; // The id of current stage uint32_t l1ListId{0}; uint32_t l0AListId{0}; uint32_t l0BListId{0};

4. 核心接口实现

4.1 构造函数

CATLASS_DEVICE BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0) { uint32_t l1AOffset = l1BufAddrStart; uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES; // Init buffers for (uint32_t i = 0; i < STAGES; i++) { // Assign L1/L0A/L0B space for each stages l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_SIZE * i); l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_SIZE * i); l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_PINGPONG_BUF_SIZE * i); l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_PINGPONG_BUF_SIZE * i); // Assign event ID for each stages l1AEventList[i] = i; l1BEventList[i] = i + STAGES; l0AEventList[i] = i; l0BEventList[i] = i + STAGES; // The event id that needs to be set before the loop AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]); AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]); AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]); AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]); } l0CTensor = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(0); AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0); }

4.2 析构函数

CATLASS_DEVICE ~BlockMmad() { for (uint32_t i = 0; i < STAGES; i++) { AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]); AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]); AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]); AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]); } AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0); }

4.3 核心计算接口 operator()

operator()是Block MMAD的核心接口,负责执行块级矩阵乘法:

CATLASS_DEVICE void operator()( AscendC::GlobalTensor<ElementA> const &gmA, LayoutA const &layoutA, AscendC::GlobalTensor<ElementB> const &gmB, LayoutB const &layoutB, AscendC::GlobalTensor<ElementC> const &gmC, LayoutC const &layoutC, GemmCoord const &actualShape) { // 1. 初始化参数和布局 // 2. 预加载第一批数据到L1缓存 // 3. 主循环: // a. 预加载下一批数据到L1缓存 // b. 将当前L1缓存中的数据加载到L0缓存 // c. 执行Tile级矩阵乘法 // d. 管理多阶段流水线 // 4. 将结果从L0缓存写回全局内存 }

5. 执行流程分析

以BlockMmadPingpong为例,其执行流程如下:

5.1 数据预加载

// load first matrix A tile from GM to L1 AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]); auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA); AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]); // load first matrix B tile from GM to L1 AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1ListId]); auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); copyGmToL1B(l1BTensorList[l1ListId], gmB, layoutBInL1, layoutTileB); AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1ListId]);

5.2 主循环(多阶段流水线)

// main loop uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k()); for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; uint32_t kActualNext{0}; // 预加载下一批数据到L1缓存 if (kLoopIdx < kTileCount - 1) { // ... 预加载逻辑 ... } // 处理当前L1缓存中的数据 // Get L1 tensor for current stage auto l1ATensor = l1ATensorList[l1ListId]; auto l1BTensor = l1BTensorList[l1ListId]; // 将L1缓存中的数据加载到L0缓存并执行计算 // ... L0级处理逻辑 ... // 切换到下一个阶段 l1ListId = l1ListIdNext; kActual = kActualNext; }

5.3 L0级处理与计算

for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { // ... M维度循环 ... for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { // ... K维度循环 ... for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { // ... N维度循环 ... // 执行Tile级矩阵乘法 bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); uint8_t unitFlag = 0b00; // ... unitFlag设置 ... tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); } } }

5.4 结果写回

// copy block out LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN()); if constexpr (!ENABLE_UNIT_FLAG) { AscendC::SetFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0); AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0); copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C); AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0); } else { copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11); }

6. 多阶段流水线与事件同步

Block MMAD使用事件驱动的同步机制确保多阶段流水线的正确执行:

// 等待事件 AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]); // 执行操作 copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA); // 触发事件 AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);

7. 不同Block MMAD实现的共性与差异

CATLASS提供了多种Block MMAD实现,它们具有以下共性:

  1. 统一的模板接口:所有实现都基于相同的模板参数结构
  2. 多阶段流水线设计:都采用多阶段设计来隐藏内存访问延迟
  3. 事件驱动同步:都使用事件机制确保数据传输和计算的正确顺序
  4. 内存层级管理:都负责管理GM、L1和L0之间的数据传输

不同实现的主要差异在于:

  1. 调度策略:如pingpong、preload、streamk等不同的调度方式
  2. 硬件适配:针对不同硬件架构的优化实现
  3. 特殊功能支持:如量化、稀疏计算、偏置处理等
  4. 性能优化:不同的流水线深度和内存访问模式

8. 总结

Block MMAD是CATLASS模板库中连接Kernel层和Tile层的关键组件,它通过以下设计实现了高效的块级矩阵乘法:

  1. 高度模块化:通过模板参数组装不同的功能组件
  2. 多阶段流水线:使用pingpong技术隐藏内存访问延迟
  3. 事件驱动同步:确保数据传输和计算的正确顺序
  4. 自动内存管理:高效管理不同层级的内存资源
  5. 灵活适配:支持不同的硬件架构和计算需求

Block MMAD的设计体现了现代高性能计算库的核心思想,通过精心的流水线设计和内存管理,充分发挥硬件的计算能力,实现高效的矩阵乘法运算。

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

http://www.jsqmd.com/news/784094/

相关文章:

  • 2026年5月国内信号隔离器品牌TOP10大盘点 - 仪表人叶工
  • 扩散模型与多模态掩码的精准图像编辑技术
  • 技术人如何用工程化思维提升学术写作效率:从工具链到结构化思维
  • CANN/xla-npu BatchMatMul优化
  • FFmpeg QSV滤镜实战:解决`get_buffer() failed`报错的两种内存访问方案对比
  • CANNBot: RoPE预计算参考
  • Taotoken的API Key管理与访问控制功能实践分享
  • 2026 年活性炭箱厂家权威排行榜 TOP5 - 小艾信息发布
  • Dart factory构造函数避坑指南:和普通构造函数的5个关键区别与性能影响
  • ARM架构TLB操作与缓存锁定机制详解
  • CANN/pyasc API文档自动生成工具使用指南
  • AI医疗在非洲的落地实践:机遇、挑战与四步走策略
  • 2026 年生物滤池权威排行榜 TOP5 - 小艾信息发布
  • 高性能计算驱动可扩展AI:科学发现新范式与工程实践
  • StateLM:大语言模型长上下文管理的创新与实践
  • 2026 年挥发性有机物(VOCs)处理领域优质企业 TOP5 - 小艾信息发布
  • Arm Neoverse V3AE调试寄存器解析与调试技巧
  • 防晒霜哪个好?这6款高倍防晒防黑防水从不踩雷 - 全网最美
  • CANN/Ascend C按位与操作API
  • 构建AI模型开放框架:从可复现性到社区协作的完整指南
  • 西北企业画册设计印刷突围秘诀:松林森彩印如何用海德堡机器打破传统工厂交期魔咒 - 企业名录优选推荐
  • 从芬兰研究看儿童AI认知误区:三类典型误解与教学应对策略
  • 用Python手把手实现电力系统潮流计算(牛顿-拉夫逊法实战)
  • 做TK怕BGM侵权?10年海外MCN亲测!5个商用音乐网免费又安全,告别静音下架 - 拾光而行
  • TTC-RL技术解析:提升大语言模型推理准确率的实时强化学习方法
  • SlimeNexus:基于Spring Boot与Vue的Minecraft服务器一体化运维管理平台
  • AI智能体安全部署指南:从Docker容器化到权限控制实战
  • 3步搭建个人游戏云:Sunshine开源串流服务器彻底解放你的游戏硬件
  • 从太湖到北极:环境工程师带你用Python分析PFAS污染数据与时空分布
  • 西安不干胶标签定制哪家强?2026年陕西印刷厂一站式服务能力横评 - 企业名录优选推荐