DeepGEMM:统一高性能张量核心内核库,多功能升级提升性能
DeepGEMM:统一的高性能张量核心内核库
DeepGEMM 是一个统一的高性能张量核心内核库,它将现代大语言模型的关键计算原语整合到一个统一的 CUDA 代码库中,这些原语包括通用矩阵乘法(GEMMs,支持 FP8、FP4、BF16)、融合专家混合(MoE)与重叠通信(Mega MoE)、闪电索引器的多查询注意力(MQA)评分、超连接(HyperConnection,HC)等。所有内核都通过轻量级即时编译(Just - In - Time,JIT)模块在运行时编译,安装过程中无需进行 CUDA 编译。
DeepGEMM 借鉴了 CUTLASS 和 CuTe 的一些概念,但避免了对它们的模板或代数的过度依赖。该库设计简洁,核心内核函数数量有限,是学习 NVIDIA GPU 内核优化技术的优质资源。尽管设计轻量,但 DeepGEMM 在各种矩阵形状下的性能与经过专家调优的库相当,甚至更优。
新闻动态
2026 年 4 月 16 日:新增 Mega MoE、FP8xFP4 GEMM、FP4 索引器、程序依赖启动(PDL)、更快的 JIT 编译等功能。性能对比后续公布,详情见 #304。
2025 年 9 月 28 日:DeepGEMM 现在支持 DeepSeek v3.2 闪电索引器的评分内核(加权 ReLU MQA 对数),详情见 #200。
2025 年 7 月 20 日:DeepGEMM 同时支持 SM90 和 SM100 架构,对低 CPU 开销的 JIT CPP 模块进行了全面重构,禁用了 NVRTC 和编译后 SASS 优化,后续将支持 NVRTC。由于 NVCC 12.9 会自动进行 FFMA 交错,不再支持所有编译后优化,详情见 #112。
2025 年 5 月 14 日:DeepGEMM 现在提供用于密集和 MoE 反向传播的权重梯度内核,详情见 #95。
2025 年 5 月 7 日:DeepGEMM 现在支持 NVRTC,编译速度最高可提升 10 倍,详情见 #94。使用 `DG_JIT_USE_NVRTC = 1` 启用(某些情况下可能会有性能损失)。
2025 年 4 月 18 日:DeepGEMM 在 H800 上实现了高达 1550 TFLOPS 的性能,详情见 #74、#78、#81、#86 和 340d988。
快速开始
要求
- NVIDIA SM90 或 SM100 架构的 GPU
- Python 3.8 或更高版本
- 支持 C++20 的编译器
- CUDA 工具包:
SM90 建议使用 CUDA 12.3 或更高版本,为获得最佳性能,强烈推荐 12.9 或更高版本
SM100 建议使用 CUDA 12.9 或更高版本
- PyTorch 2.1 或更高版本
- CUTLASS 4.0 或更高版本(可通过 Git 子模块克隆)
- {fmt} 库(可通过 Git 子模块克隆)
开发
```bash
# 必须克隆子模块
git clone --recursive git@github.com:deepseek - ai/DeepGEMM.git
cd DeepGEMM
# 链接一些必要的头文件并构建 CPP JIT 模块
cat develop.sh
./develop.sh
```
安装
```bash
cat install.sh
./install.sh
```
然后,在你的 Python 项目中导入 `deep_gemm` 即可开始使用!
接口说明
通用信息
该库为 NVIDIA GPU 提供优化的 GEMM 内核,命名规则为 `D = C + A @ B`。输入形状布局为 NT(A 不转置,B 转置)。SM90 实现仅支持 NT 内存布局(行主序、列主序),而 SM100 实现支持所有内存布局(NT、TN、NN、TT)。例如,`fp8_gemm_nt` 会执行 `D = C + A @ B.T`。
对于两种架构,左侧缩放因子都需要具有 TMA 对齐和转置的布局。SM90 和 SM100 的缩放因子数据格式不同:SM90 需要 FP32 格式的缩放因子,SM100 需要打包的 UE8M0 格式,即将 4 个 UE8M0 打包成一个 `torch.int`。
请注意,输入转置或 FP8 转换等操作需要用户单独处理,请自行实现或将其融合到之前的内核中。虽然库中提供了一些简单的 PyTorch 实用函数,但可能会导致性能下降,库的主要重点是优化 GEMM 内核本身。
普通密集 GEMMs(非分组)
要执行基本的非分组 FP8 GEMM,调用 `fp8_gemm_{nt, nn, tn, tt}` 函数,具体详情请参考函数文档。
分组 GEMMs(连续布局)
与 CUTLASS 中的传统分组 GEMMs 不同,DeepGEMM 仅对 M 轴进行分组,N 和 K 必须保持固定。这种设计适用于 MoE 模型中专家共享相同形状的场景。在训练前向传播或推理预填充阶段,每个专家可能处理不同数量的令牌,我们将这些令牌连接成一个张量,即“连续”布局。请注意,每个专家段必须与 GEMM M 块大小对齐(`get_mk_alignment_for_contiguous_layout()`)。更多信息请参考 `m_grouped_fp8_gemm_{nt, nn}_contiguous` 函数文档。我们还提供了用于 MoE 权重反向传播的 K 轴分组 API(M 和 N 必须保持固定),详情请参考 `k_grouped_fp8_gemm_tn_contiguous`。
分组 GEMMs(掩码布局)
在推理解码阶段,当启用 CUDA 图且 CPU 不知道每个专家接收的令牌数量时,我们支持掩码分组 GEMMs。通过提供掩码张量,内核仅计算有效部分。使用 `m_grouped_fp8_gemm_nt_masked` 并参考相关文档。一个示例用法是使用 DeepEP 低延迟内核的输出作为输入。
V3.2 索引器的 MQA 内核
该内核家族有两个版本,非分页(用于预填充)和分页(用于解码)。以非分页版本 `fp8_mqa_logits` 为例,它有 6 个输入:
- `q`:形状为 `[seq_len, num_heads, head_dim]` 的 E4M3 张量
- `kv`:形状为 `[seq_len_kv, head_dim]` 的 E4M3 张量,带有形状为 `[seq_len_kv]` 的浮点缩放因子
- `weights`:形状为 `[seq_len, num_heads]` 的浮点张量
- `cu_seq_len_k_start` 和 `cu_seq_len_k_end`:形状为 `[seq_len]` 的整数张量
- `clean_logits`:是否将未填充的对数清零为 `- inf`
输出张量形状为 `[seq_len, seq_len_kv]`,表示令牌到令牌的对数。对于 `q` 中的每个令牌 `i`,它将遍历 `[cu_seq_len_k_start[i], cu_seq_len_k_end[i])` 中的所有令牌 `j`,并计算对数 `out[i, j]` 如下:
```python
kv_j = kv[0][j, :] * kv[1][j].unsqueeze(1) # [head_dim]
out_ij = q[i, :, :] @ kv_j # [num_heads]
out_ij = out_ij.relu() * weights[i, :] # [num_heads]
out_ij = out_ij.sum() # 标量
```
更多详情和分页版本 `fp8_paged_mqa_logits` 请参考 `tests/test_attention.py`。
Mega MoE
Mega MoE 将专家并行(EP)调度、线性层 1(FP8xFP4)、SwiGLU、线性层 2(FP8xFP4)和 EP 合并融合到一个巨型内核中,实现了 NVLink 通信和张量核心计算的重叠。它需要使用对称内存进行多进程启动。
使用方法:
```python
# 分配对称内存缓冲区
# 注意:需要 PyTorch >= 2.9
buffer = deep_gemm.get_symm_buffer_for_mega_moe(
group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden
)
# 将权重(FP4 带 UE8M0 缩放因子)转换为所需布局
transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)
# 在每次调用前将输入复制到缓冲区
# 你可以将这些操作融合到之前的内核中
buffer.x[:num_tokens].copy_(x_fp8)
buffer.x_sf[:num_tokens].copy_(x_sf)
buffer.topk_idx[:num_tokens].copy_(topk_idx)
buffer.topk_weights[:num_tokens].copy_(topk_weights)
# 运行融合的 Mega MoE 内核
y = torch.empty((num_tokens, hidden), dtype = torch.bfloat16, device = 'cuda')
deep_gemm.fp8_fp4_mega_moe(y, transformed_l1, transformed_l2, buffer)
```
完整的多进程设置和基准测试示例请参考 `tests/test_mega_moe.py`。
实用函数
除了上述内核,库还提供了一些实用函数:
- `deep_gemm.set_num_sms / get_num_sms`:设置/获取要使用的最大 SM 数量
- `deep_gemm.set_tc_util / get_tc_util`:设置/获取近似的张量核心利用率
- `deep_gemm.set_pdl / get_pdl`:启用/禁用程序依赖启动(PDL)
- `deep_gemm.set_mk_alignment_for_contiguous_layout / get_mk_alignment_for_contiguous_layout`:设置/获取连续布局的组级 M/K 对齐
- `deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout`:获取理论最小 M/K 对齐
- `deep_gemm.set_ignore_compile_dims`:配置 JIT 编译时要忽略的维度
- `deep_gemm.set_block_size_multiple_of`:将块大小限制为给定值的倍数
- `deep_gemm.transform_sf_into_required_layout`:将缩放因子转换为所需布局
- `deep_gemm.get_tma_aligned_size`:获取所需的 TMA 对齐大小
- `deep_gemm.get_mn_major_tma_aligned_tensor`:获取 MN 主序 TMA 对齐的张量
- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`:获取 MN 主序 TMA 对齐打包成 UE8M0 的张量
- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`:K 分组 GEMM 打包内核
环境变量
通用
- `DG_JIT_DEBUG`:0 或 1,打印 JIT 调试信息,默认为 0
- `DG_PRINT_CONFIGS`:0 或 1,打印每个形状的选定配置,默认为 0
JIT 缓存
- `DG_JIT_CACHE_DIR`:字符串,编译内核的缓存目录,默认为 `$HOME/.deep_gemm`
编译器选择
- `DG_JIT_USE_NVRTC`:0 或 1,使用 NVRTC 代替 NVCC(编译速度更快,某些情况下可能性能较低),默认为 0
- `DG_JIT_NVCC_COMPILER`:字符串,NVCC 编译器路径;默认为 `torch.utils.cpp_extension.CUDA_HOME`
- `DG_JIT_CPP_STANDARD`:整数,C++ 标准版本,默认为 20
编译器输出
- `DG_JIT_PRINT_COMPILER_COMMAND`:0 或 1,打印编译命令,默认为 0
- `DG_JIT_PTXAS_VERBOSE`:0 或 1,显示详细的 PTXAS 输出,默认为 0
- `DG_JIT_PTXAS_CHECK`:0 或 1,断言编译内核中无本地内存使用,默认为 0
- `DG_JIT_PRINT_LOAD_TIME`:0 或 1,打印内核加载时间,默认为 0
调试和性能分析
- `DG_JIT_WITH_LINEINFO`:0 或 1,为性能分析工具嵌入源代码行信息,默认为 0
- `DG_JIT_DUMP_ASM`:0 或 1,转储 PTX 和 SASS,默认为 0
- `DG_JIT_DUMP_PTX`:0 或 1,转储 PTX 输出,默认为 0
- `DG_JIT_DUMP_SASS`:0 或 1,转储 SASS 输出,默认为 0
- `DG_COMM_KERNEL_DEBUG`:0 或 1,在每次 Mega MoE 调用前将对称缓冲区清零以进行调试,默认为 0
- `DG_USE_NVIDIA_TOOLS`:0 或 1,在外部 NVIDIA 工具下运行时跳过内部性能分析,默认为 0
构建选项
- `DG_SKIP_CUDA_BUILD`:0 或 1,安装过程中跳过 CUDA 扩展构建,默认为 0
- `DG_FORCE_BUILD`:0 或 1,强制本地构建而不是下载预构建的 wheel 文件,默认为 0
- `DG_JIT_USE_RUNTIME_API`:0 或 1,使用 CUDA 运行时 API 加载内核(需要 CUDA 运行时 >= 12.8),默认为 0
更多示例和详细信息,请参考测试代码或查看相应的 Python 文档。
致谢
DeepGEMM 受到 CUTLASS 项目的启发,感谢并尊重开发者们!
许可证
此代码库根据 MIT 许可证发布。
引用
```bibtex
@misc{deepgemm2025,
title={DeepGEMM: clean and efficient BLAS kernel library on GPU},
author={Chenggang Zhao and Zhean Xu and Liang Zhao and Jiashi Li and Chenhao Xu and Anyi Xu and Shengyu Liu and Kexing Zhou and Kuai Yu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek - ai/DeepGEMM}},
}
```
