DeepSeek-R1 在 CANN 上的推理部署
本文基于昇腾CANN和昇腾NPU,围绕 cann-recipes-infer 仓库的相关技术展开。
DeepSeek-R1 是个 MoE 模型——671B 总参数但每次推理只激活 37B。这对推理系统是个结构性的挑战:MoE 的路由选择和 Expert 调度依赖通信,CANN 的集合通信库 HCCL 和单边通信库 hixl 构成了 MoE 推理的通信底座。
MoE 的推理计算图
# DeepSeek-R1 的 MoE 层——路由 + Expert 计算classMoELayer(torch.nn.Module):def__init__(self,num_experts=256,top_k=8,expert_dim=4096):super().__init__()self.num_experts=num_experts self.top_k=top_k# 路由 Gate——决定每个 Token 发到哪些 Expertself.gate=torch.nn.Linear(expert_dim,num_experts,bias=False)# Expert 网络——256 个 FFN,每个是 2 层 MLPself.experts=torch.nn.ModuleList([torch.nn.Sequential(torch.nn.Linear(expert_dim,2*expert_dim*4),# SwiGLUtorch.nn.Linear(2*expert_dim*4,expert_dim),)for_inrange(num_experts)])defforward(self,x):# x: [batch * seq_len, expert_dim]B,D=x.shape# Step 1: Gate 算路由分数gate_logits=self.gate(x)# [B, 256]gate_scores=torch.softmax(gate_logits,dim=-1)# Step 2: Top-K 选择——每个 Token 选 8 个 Experttopk_weights,topk_indices=torch.topk(gate_scores,self.top_k,dim=-1)topk_weights=topk_weights/topk_weights.sum(dim=-1,keepdim=True)# Step 3: 分发 Token 到对应 Expert# 每个 Expert 收到的 Token 集合# 这步需要 All-to-All 通信——Expert 分布在不同卡上dispatched=self.dispatch_tokens(x,topk_indices)# Step 4: 每个 Expert 算自己的那部分expert_outputs=[]fori,expertinenumerate(self.experts):iflen(dispatched[i])>0:out=expert(dispatched[i])expert_outputs.append((i,out))# Step 5: 收集 Expert 输出 = All-to-All 反向通信# 把各卡 Expert 的结果收回到对应 Token 位置output=self.collect_outputs(expert_outputs,topk_indices,topk_weights)returnoutputCANN 上的 MoE 通信模式
// DeepSeek-R1 的 MoE 推理——8 卡 Expert 并行的通信模式classMoEInferenceExecutor{// 每张卡部署 256/8 = 32 个 Expertconstintexperts_per_device=32;// 路由结果的 Token 分发——用 HCCL 的 All-to-AllvoidTokenDispatch(int*topk_indices,float*hidden_states,intnum_tokens,intnum_devices){// Step 1: 统计每个 Device 要发多少 Tokenintsend_counts[8]={0};intsend_displs[9]={0};for(intt=0;t<num_tokens;t++){for(intk=0;k<top_k;k++){intexpert_id=topk_indices[t*top_k+k];intdevice_id=expert_id/experts_per_device;send_counts[device_id]++;}}// 算 displacement 用于 scatterfor(intd=1;d<=num_devices;d++){send_displs[d]=send_displs[d-1]+send_counts[d-1];}// Step 2: 用 HCCL 做 All-to-All——NVIDIA 的 AlltoAll 对应// CANN 的 HCCL 支持 HcclAlltoAllV——不等长收发HcclAlltoAllV(hidden_states,send_counts,send_displs,// 发送recv_buffer,recv_counts,recv_displs,// 接收HCCL_FLOAT,num_devices,hccl_comm);}// Step 3: 各卡算完 Expert 后,反向 All-to-All 收回voidTokenCollect(float*expert_output,int*topk_indices,float*topk_weights,float*final_output){// Expert 输出同样 All-to-All 回去HcclAlltoAllV(expert_output,recv_counts,recv_displs,final_buffer,send_counts,send_displs,HCCL_FLOAT,num_devices,hccl_comm);// 按 TopK 权重加权合并——每个 Token 的 8 个 Expert 结果for(intt=0;t<num_tokens;t++){for(intk=0;k<top_k;k++){intidx=topk_indices[t*top_k+k];floatweight=topk_weights[t*top_k+k];// 累加for(intd=0;d<hidden_dim;d++){final_output[t*hidden_dim+d]+=final_buffer[idx*hidden_dim+d]*weight;}}}}};All-to-All 通信是 MoE 推理的瓶颈。DeepSeek-R1 每层 MoE 要做两次 All-to-All(分发+收集),80 层就是 160 次。CANN 的 HCCL 用 NVLink 等价的卡间互联拓扑做 AlltoAllV 优化——让 Token 分布均衡的设备路由不走跨交换机。
PD 分离架构下的 DeepSeek-R1
# DeepSeek-R1 的 PD(Prefill-Decode)分离部署classDeepSeekPDSeparation:""" DeepSeek-R1·671B 推荐用 PD 分离部署: - Prefill 阶段:计算密集型,配较少卡 + 大 Batch - Decode 阶段:访存密集型,配较多卡 + 小 Batch CANN 的 hixl 支持零拷贝的单边通信——Prefill 算完的 KV Cache 直接暴露给 Decode 卡读,不用显式搬运。 """def__init__(self):# Prefill 池:4 张卡,每卡处理 16 个请求的 Prefillself.prefill_pool=[4,"ascend910","prefill"]# Decode 池:16 张卡,每卡处理 8 个请求的 Decodeself.decode_pool=[16,"ascend910","decode"]# hixl 初始化——零拷贝共享内存importhixl self.shared_kv=hixl.SharedMemory(size_per_token=128*64*2*2,# 128heads × 128dim × FP16 × KVnum_devices=20# 4+16)defhandoff(self,request_id):""" Prefill 完成后把 KV Cache 句柄传给 Decode 卡 hixl 用远端内存直接映射——不用走 HCCL 搬运 """kv_handle=self.prefill_pool.export_kv(request_id)# kv_handle 包含:物理地址 + 长度 + 设备 ID# Decode 卡通过 hixl 的 rdma_read 直接读self.decode_pool.import_kv(request_id,kv_handle)# 零拷贝——实际只有一次 PCIe/NVLink 的读取DeepSeek-R1 的 256 Expert × 8 TopK 的稀疏激活特点,让 PD 分离 + MoE All-to-All 成为推理系统设计的关键。CANN 在这一场景的独特优势是 hixl 的单边通信——PD 分离场景下 KV Cache 的零拷贝传输能省掉 30% 的卡间带宽。
参考仓库
DeepSeek-R1 推理配方
MoE 相关 Transformer 算子
hixl 单边通信库
HCCL 集合通信
