当前位置: 首页 > news >正文

CUDA高性能计算系列10:实战手写深度学习算子(Softmax)

CUDA高性能计算系列10:实战手写深度学习算子(Softmax)

摘要:纸上得来终觉浅,绝知此事要躬行。学了这么多优化技巧,是时候检验真功夫了。本篇我们将深入深度学习中最常见的算子之一——Softmax。看似简单的公式背后,隐藏着数值溢出的陷阱和并行归约的挑战。我们将手写一个能够与 PyTorch 原生性能抗衡的 Softmax Kernel。


1. Softmax 的数学原理与挑战

Softmax 函数将一个向量x xx映射为概率分布y yy
y i = e x i ∑ j e x j y_i = \frac{e^{x_i}}{\sum_{j} e^{x_j}}yi=jexjexi

1.1 数值稳定性问题 (Numerical Stability)

直接计算e x i e^{x_i}exi非常危险。
如果x i = 100 x_i = 100xi=100,则e 100 ≈ 2.6 × 10 43 e^{100} \approx 2.6 \times 10^{43}e1002.6×1043,这在 FP32 范围内没问题。
但如果x i = 1000 x_i = 1000xi=1000,则e 1000 → ∞ e^{1000} \to \inftye1000(Inf),导致 NaN 错误。

解决方案:减去最大值。
y i = e x i − max ⁡ ( x ) ∑ j e x j − max ⁡ ( x ) y_i = \frac{e^{x_i - \max(x)}}{\sum_{j} e^{x_j - \max(x)}}yi=jexjmax(x)eximax(x)
这样所有指数的指数项都在( − ∞ , 0 ] (-\infty, 0](,0]之间,结果在( 0 , 1 ] (0, 1](0,1]之间,永远不会上溢。

1.2 计算流程

这就将一个 Softmax 变成了三个阶段的计算:

  1. Reduce Max: 找到当前行的最大值m mm
  2. Reduce Sum: 计算S = ∑ e x i − m S = \sum e^{x_i - m}S=exim
  3. Element-wise Update: 计算y i = e x i − m / S y_i = e^{x_i - m} / Syi=exim/S

这就意味着我们需要遍历数据三次!如何高效地由 GPU 完成?


2. 架构设计:Grid, Block, Warp

假设输入张量形状为[Batch_Size, Dim]
通常Batch_Size很大,Dim变化范围广(从 100 到 10000+)。

2.1 策略:一行一个 Block

  • Grid Size:Batch_Size。每个 Block 处理一行数据。
  • Block Size: 256 或 1024。

如果Dim很小(< 1024),一个 Block 刚好能装下,直接用 Shared Memory 归约。
如果Dim很大,Block 需要循环处理(Grid-Stride Loop 变体)。


3. Kernel 实现:One-Pass 还是 Three-Pass?

为了教学清晰,我们先实现一个标准的Three-Pass逻辑,但在同一个 Kernel 内完成(避免多次启动 Kernel 的开销)。

#include<cuda_runtime.h>#include<math.h>// 辅助函数:Warp 内求最大值__device__floatwarpReduceMax(floatval){for(intoffset=16;offset>0;offset/=2)val=fmaxf(val,__shfl_down_sync(0xffffffff,val,offset));returnval;}// 辅助函数:Warp 内求和__device__floatwarpReduceSum(floatval){for(intoffset=16;offset>0;offset/=2)val+=__shfl_down_sync(0xffffffff,val,offset);returnval;}__global__voidsoftmax_kernel(float*input,float*output,intdim){// 1. 设置索引// blockIdx.x 对应 batch 维度(行号)introw_idx=blockIdx.x;// 指向当前行的起始地址float*row_input=input+row_idx*dim;float*row_output=output+row_idx*dim;// 2. 阶段一:求最大值 (Reduce Max)floatmax_val=-INFINITY;// 循环处理,防止 dim > blockDim.xfor(inti=threadIdx.x;i<dim;i+=blockDim.x){max_val=fmaxf(max_val,row_input[i]);}// Block 内规约最大值// 这里使用 Shared Memory 进行 Block 级规约(简化版,假设 Block=256,1个Warp处理不了)// 为了简单,我们只展示 Warp 级规约逻辑,实际需配合 Shared Memorymax_val=warpReduceMax(max_val);// 通过 Shared Memory 广播最大值给所有线程__shared__floats_max;if(threadIdx.x==0)s_max=max_val;__syncthreads();max_val=s_max;// 3. 阶段二:求指数和 (Reduce Sum)floatsum=0.0f;for(inti=threadIdx.x;i<dim;i+=blockDim.x){sum+=expf(row_input[i]-max_val);}sum=warpReduceSum(sum);__shared__floats_sum;if(threadIdx.x==0)s_sum=sum;__syncthreads();sum=s_sum;// 4. 阶段三:计算最终结果for(inti=threadIdx.x;i<dim;i+=blockDim.x){row_output[i]=expf(row_input[i]-max_val)/sum;}}

3.1 深度优化:Online Softmax

传统的 Softmax 需要遍历数据 3 次(Max -> Sum -> Update)。
有一种算法叫Online Softmax,利用数学技巧只需要遍历 2 次甚至更少。

公式推导:
维护当前的局部最大值m mm和局部和d dd
当遇到一个新的元素x xx时:

  • x > m x > mx>mm n e w = x m_{new} = xmnew=x,d n e w = d × e m − x + 1 d_{new} = d \times e^{m - x} + 1dnew=d×emx+1
  • x ≤ m x \le mxmm n e w = m m_{new} = mmnew=m,d n e w = d + e x − m d_{new} = d + e^{x - m}dnew=d+exm

这种方法可以在一次遍历中同时更新最大值和和,极大减少 Global Memory 访问。


4. 性能瓶颈分析

  1. Memory Bound: Softmax 是典型的Element-wise操作,计算量很小(也就 exp 和 div),主要时间都花在读写内存上。
  2. 优化方向
    • 确保 Global Memory 的合并访问(我们已经做到了,行内元素是连续的)。
    • 尽量把数据留在寄存器或 Shared Memory 中,避免重复读取 input。

5. 向量化读取 (Vectorized Load)

在处理 FP32 时,我们可以使用float4类型,一次读取 128 bit(4 个 float)。这能显著提高带宽利用率,减少指令数。

// 重新解释指针float4*vec_input=reinterpret_cast<float4*>(row_input);// 每次处理 4 个元素float4 data=vec_input[threadIdx.x];// ... 分别处理 data.x, data.y, data.z, data.w ...

限制:要求Dim必须是 4 的倍数,且地址必须对齐。实际工程中需要处理边界条件。


6. 总结与下篇预告

编写一个高性能的 Softmax 算子,不仅需要 CUDA 编程技巧(Shared Memory, Warp Shuffle),还需要深厚的数值分析功底(防止溢出)和算法优化思路(Online Softmax)。

至此,我们的 Kernel 代码已经能够跑在 GPU 上了。但是,怎么让 Python 里的 PyTorch 调用它呢?难道每次都要把数据存成文件,用 C++ 跑完再读回来吗?

当然不是!
下一篇CUDA系列11_PyTorch自定义C++扩展(Binding),我们将打通任督二脉,教你使用torch.utils.cpp_extension将我们写的 CUDA Kernel 编译成 Python 模块。届时,你只需要import my_cuda_ops,就能在 Python 里直接享用你亲手打造的高性能算子!


参考文献

  1. Milakov, M., & Gimelshein, N.Online Normalizer Calculation for Softmax. arXiv:1805.02867.
  2. OneFlow Team.How to Implement an Efficient Softmax Kernel.
http://www.jsqmd.com/news/222301/

相关文章:

  • 医疗用AutoGluon自动建模
  • 大规模数据检索优化:elasticsearch官网核心要点
  • 从0到1搭建实时日志监控系统:基于WebSocket + Elasticsearch的实战方案
  • 协同过滤性能优化技巧:高并发场景应用
  • 做“自适应PINN”的赢麻了,连发TOP刊的感觉太爽了!
  • 毕业论文AI怎么查重?我的血泪经验+实用工具大公开
  • 通俗解释nmodbus4在.NET Framework与Core的区别
  • 【气动学】最优控制理论的归导定律和撞击角控制【含Matlab源码 14887期】含报告
  • 如何高效部署专业翻译模型?HY-MT1.5-7B镜像一键启动指南
  • AVD无法运行?一文说清Intel HAXM安装全流程
  • 零基础掌握cp2102与Modbus协议的工业通信对接
  • Neo4j中的Cypher查询优化技巧
  • 一文说清电路仿真circuits网页版中的反馈电路原理
  • 工业机器人通信前的USB转232驱动安装准备指南
  • 解决NumPy ImportError问题的实践与思考
  • 【图像隐写】快速四元数通用极坐标复指数变换的彩色图像零水印【含Matlab源码 14889期】
  • CANFD协议仲裁场解析:核心要点说明
  • 实战案例:基于车载雷达模块的CANFD与CAN对比
  • Linux tcpdump工具的使用
  • 零基础必看,1小时速通 从JavaSE到SpringBoot框架,搞定企业刚需技术!
  • CUDA 11.0 共享库缺失:环境配置实战案例解析
  • CES观察|AI硬件迎来黄金时代,中国机器人“进场打工”
  • 计算降雨间隔:使用purrr包的优雅方法
  • 统一监控多个ES集群:可视化管理工具实战解析
  • 基于Java+SpringBoot+SSM智能水产养殖管理系统(源码+LW+调试文档+讲解等)/智能渔业养殖管理系统/水产养殖智能化系统/水产智能管理平台/智能水产养殖技术/水产养殖监控管理系统
  • 机动绞磨机,长云科技电信工程牵引绞磨
  • vivado2023.2下载安装教程操作指南:专为Artix-7优化
  • AUTOSAR中Vector工具链的DBC与ARXML转换实战案例
  • 小红书Java面试被问:TCC事务的悬挂、空回滚问题解决方案
  • ChatGPT的尽头是A2UI?谷歌重磅新标准:让AI学会“做界面”,重新定义人机交互!