【Tilelang入门】Tilelang Puzzles 08
Tillang Puzzles
一个开源仓库https://github.com/tile-ai/tilelang-puzzles/tree/main
给出用tilelang实现经典算子的例子,附带讲解。分为10个puzzle,每个问题都有待补全文件,和参考实现,以及文字讲解。
采用循序渐进的思路,难度逐渐递增,01-05熟悉语法,06-09实现经典算子,10为挑战复杂实战算子
08 Matirx
矩阵乘法是整个算子优化的核心,大部分经典算子都可以规约到矩阵乘法,比如前一节的注意力里有QKTQK^TQKT,前向传播有参数*输入,反向传播的求偏导也是矩阵乘法,卷积通过im2col转化后也能变成矩阵乘法。
GEMV
先来个基础的,矩阵乘向量,可以看成矩阵乘法的特殊情况,N=1
定义
foriinrange(M):ACC=0# float32 累加器forkinrange(K):ACC+=A[i,k]*B[k]C[i]=ACC# 转换回 float16实际上也可以看成规约求和的特殊情况,看成带权规约,普通规约权重都是1,这里的权重是B[k]
# Reduce Sum (Puzzle 05)foriinrange(N):C[i]=sum(A[i,:])# GEMV (Puzzle 08)foriinrange(M):C[i]=sum(A[i,:]*B[:])# 加权求和baseline
defref_gemv(A:torch.Tensor,B:torch.Tensor):assert A.shape==(M,K)assert B.shape==(K,)assert A.dtype==B.dtype==torch.float16returntorch.matmul(input=A,other=B)# 返回[M,]C_local = T.alloc_fragment((BLOCK_M,), accum_dtype),累加类型使用fp32,不同于输入类型fp16,因为fp16不管是精度还是值域都太小了,矩阵乘法有乘法,有累加,数值很大,用32才能保证不溢出+精度过关AB_local[i, j] = A_local[i, j].astype(accum_dtype) * B_local[j].astype(accum_dtype)按前面说的,看成带权规约,先计算乘上权重后的结果。由于输入是fp16,还想保证精度,计算时先显式转成fp32,类似于cpp里的ans += 1ll * x * y- 然后规约
T.reduce_sum(AB_local, C_local, dim=1, clear=False)
@tilelang.jitdeftl_gemv(A,B,BLOCK_M:int,BLOCK_K:int):M,K=T.const("M, K")dtype=T.float16 accum_dtype=T.float32 A:T.Tensor((M,K),dtype)B:T.Tensor((K,),dtype)C=T.empty((M,),dtype)# TODO: Implement this functionwithT.Kernel(T.ceildiv(M,BLOCK_M),threads=128)asbx:A_local=T.alloc_fragment((BLOCK_M,BLOCK_K),dtype)B_local=T.alloc_fragment((BLOCK_K,),dtype)C_local=T.alloc_fragment((BLOCK_M,),accum_dtype)AB_local=T.alloc_fragment((BLOCK_M,BLOCK_K),accum_dtype)T.clear(C_local)forbkinT.Serial(T.ceildiv(K,BLOCK_K)):T.copy(A[bx*BLOCK_M,bk*BLOCK_K],A_local)T.copy(B[bk*BLOCK_K],B_local)fori,jinT.Parallel(BLOCK_M,BLOCK_K):AB_local[i,j]=A_local[i,j].astype(accum_dtype)*B_local[j].astype(accum_dtype)T.reduce_sum(AB_local,C_local,dim=1,clear=False)T.copy(C_local,C[bx*BLOCK_M])returnC
性能很差,这算是不叫暴力的做法
朴素GEMM
T.gemm(A_local, B_local, C_local)和前面唯一的区别,把手动乘上权重,再逐行规约,改成调用gemm接口计算一个块的结果了,只需传入两个输入矩阵,一个接收矩阵。
@tilelang.jitdeftl_matmul_naive(A,B,BLOCK_M:int,BLOCK_N:int,BLOCK_K:int):M,N,K=T.const("M, N, K")dtype=T.float16 accum_dtype=T.float32 A:T.Tensor((M,K),dtype)B:T.Tensor((K,N),dtype)C=T.empty((M,N),dtype)# TODO: Implement this functionwithT.Kernel(T.ceildiv(M,BLOCK_M),T.ceildiv(N,BLOCK_N),threads=128)as(bx,by):A_local=T.alloc_fragment((BLOCK_M,BLOCK_K),dtype)B_local=T.alloc_fragment((BLOCK_K,BLOCK_N),dtype)C_local=T.alloc_fragment((BLOCK_M,BLOCK_N),accum_dtype)T.clear(C_local)forbkinT.Serial(T.ceildiv(K,BLOCK_K)):T.copy(A[bx*BLOCK_M,bk*BLOCK_K],A_local)T.copy(B[bk*BLOCK_K,by*BLOCK_N],B_local)T.gemm(A_local,B_local,C_local)T.copy(C_local,C[bx*BLOCK_M,by*BLOCK_N])returnC
性能还是很差,而且看起来5x比前面的GEMV的3x左右还要差?难道gemm还不如手动规约高效?显然不是,因为GEMM两个输入都是矩阵,计算规模更大了,那么效率差距会被放大。实际GEMM肯定效率是比手动规约高的。
GEMM内部一般会直接调用MMA接口,使用Tensor Core计算。Tensor Core是矩阵计算专用单元,只能级算特定大小的矩阵乘法,不灵活,但是计算吞吐量大;前面的reduce和拷贝这些操作,都是CUDA Core执行的,可以执行通用计算,但是效率低。所以调用Tensor Core的GEMM接口一般性能肯定比CUDA Core的reduce高。
这里我们和torch还有差距,说明还有优化没用上。
优化版GEMM
B_local = T.alloc_shared((BLOCK_K, BLOCK_N), dtype)这里把张量从寄存器换到了共享内存上,明明寄存器更快,这是为什么?因为寄存器资源是很有限的,基本装下一个(BLOCK_M, BLOCK_K)大的张量就快满了,两个就不够了,而溢出部分的数据,会被直接存到全局内存,全局内存的延迟是最高的,整体效率被这个环节完全拖慢了,什么优化都没用了。所以,只有最频繁用到的累加数组,我们考虑安排在寄存器上,两个输入分块安排在共享内存,共享内存一般很大,容纳多个块都没问题。同时,访问速度也还可以接受。for bk in T.Pipelined(T.ceildiv(K, BLOCK_K), num_stages=3):又是tilelang的一个强大接口,这里可以在并行循环里,增加一个参数num_stage,指定流水线级数。就可以把这个循环流水线化!理论上合适的流水阶段划分,可以实现等同于级数的加速比!- 这里虽然轻飘飘的一行,实际内部优化思想非常重要,注意到流水线建立前,每一轮的循环执行的是,先拷贝,再计算。但是我们前面提到过,现代GPU的内存带宽远小于计算吞吐量,也就是说大部分时间,计算核心都处于阻塞,等待内存搬运,这正是适合流水线优化的地方,可以把搬运和计算解耦,流水线一个阶段负责搬运,一个阶段负责计算,这样搬运的时候也可以计算,大大提升效率,整体瓶颈只取决于最慢的部分,也就是搬运,计算延迟几乎被完全隐藏了。
@tilelang.jitdeftl_matmul_opt(A,B,BLOCK_M:int,BLOCK_N:int,BLOCK_K:int):M,N,K=T.const("M, N, K")dtype=T.float16 accum_dtype=T.float32 A:T.Tensor((M,K),dtype)B:T.Tensor((K,N),dtype)C=T.empty((M,N),dtype)# TODO: Implement this functionwithT.Kernel(T.ceildiv(M,BLOCK_M),T.ceildiv(N,BLOCK_N),threads=128)as(bx,by):A_local=T.alloc_shared((BLOCK_M,BLOCK_K),dtype)B_local=T.alloc_shared((BLOCK_K,BLOCK_N),dtype)C_local=T.alloc_fragment((BLOCK_M,BLOCK_N),accum_dtype)T.clear(C_local)forbkinT.Pipelined(T.ceildiv(K,BLOCK_K),num_stages=3):T.copy(A[bx*BLOCK_M,bk*BLOCK_K],A_local)T.copy(B[bk*BLOCK_K,by*BLOCK_N],B_local)T.gemm(A_local,B_local,C_local)T.copy(C_local,C[bx*BLOCK_M,by*BLOCK_N])returnC优化都用上后和torch实现差的不多了。
另外来验证一下,所有张量都申请在寄存器上会不会导致性能退化。可以看到几乎退化到朴素GEMM版本了,这是合理的,因为朴素版本就是直接访问全局内存,这里内存溢出后,张量也是存在全局内存上的,访问延迟自然和直接存在全局内存上相近。
