CANN 模型转换与适配:从 PyTorch 到 Ascend OM 的完整指南
模型转换是昇腾落地的第一道坎。不管你用 PyTorch、TensorFlow 还是 MindSpore,最终都要变成 Ascend 的.om模型才能在 NPU 上跑。
这篇文章讲清楚:模型转换的完整流程、常见问题和优化技巧。
为什么需要模型转换?
昇腾 NPU 不能直接运行 PyTorch 的.pt模型。原因有两个:
- 硬件指令集不同:PyTorch 编译成的是 CUDA 指令,昇腾用的是达芬奇架构的指令
- 运行时不同:PyTorch 用的是 CUDA 运行时,昇腾用的是 AscendCL 运行时
所以要把模型"翻译"成昇腾能认识的形式。
模型转换的三条路
路径 1:PyTorch → ONNX → ATC → OM(最常用) 路径 2:PyTorch → TorchScript → ATC → OM 路径 3:TensorFlow/Paddle → ATC → OM推荐路径 1:PyTorch → ONNX → ATC → OM。这是官方推荐的方式,兼容性最好。
路径 1:PyTorch → ONNX → ATC → OM
这是最常用的路径,分两步完成。
步骤 1:PyTorch → ONNX
importtorchimporttorch.nnasnn# 定义一个简单的 Transformer 模型classSimpleTransformer(nn.Module):def__init__(self,vocab_size=50000,hidden_dim=768,num_heads=12):super().__init__()self.embedding=nn.Embedding(vocab_size,hidden_dim)self.attention=nn.MultiheadAttention(hidden_dim,num_heads,batch_first=True)self.fc=nn.Linear(hidden_dim,vocab_size)defforward(self,input_ids,attention_mask=None):x=self.embedding(input_ids)attn_out,_=self.attention(x,x,x,attn_mask=attention_mask)logits=self.fc(attn_out)returnlogits# 实例化模型model=SimpleTransformer()model.eval()# 导出 ONNXdummy_input=torch.randint(0,50000,(1,512))torch.onnx.export(model,dummy_input,"transformer.onnx",input_names=["input_ids","attention_mask"],output_names=["logits"],dynamic_axes={"input_ids":{0:"batch",1:"seq_len"},"attention_mask":{0:"batch",1:"seq_len"},"logits":{0:"batch",1:"seq_len",2:"vocab"}},opset_version=14,do_constant_folding=True)步骤 2:ONNX → OM(使用 ATC 编译器)
# 基础转换命令atc--model=transformer.onnx\--output=transformer\--framework=5\--soc_version=Ascend910\--input_shape="input_ids:[1,512]"\--input_shape="attention_mask:[1,512]"\--log=infoATC 核心参数详解
| 参数 | 说明 | 常见值 |
|---|---|---|
--model | 输入模型路径 | model.onnx |
--output | 输出模型路径(不含扩展名) | model |
--framework | 输入框架类型 | 5=ONNX, 3=TensorFlow, 0=Caffe |
--soc_version | 目标芯片 | Ascend910,Ascend310 |
--input_shape | 输入张量形状 | input_ids:[1,512] |
--precision_mode | 精度模式 | allow_fp16,force_fp16,allow_mixed_precision |
--dynamic_batch | 动态 batch | 1,2,4,8 |
--dynamic_dims | 动态维度 | 16,32,64 |
动态 batch 示例
# 支持 batch=1,2,4,8atc--model=transformer.onnx\--output=transformer\--framework=5\--soc_version=Ascend910\--input_shape="input_ids:[1,512]"\--input_shape="attention_mask:[1,512]"\--dynamic_batch="1,2,4,8"\--log=info动态序列长度示例
# 支持 seq_len=16,32,64,128,256,512atc--model=transformer.onnx\--output=transformer\--framework=5\--soc_version=Ascend910\--input_shape="input_ids:[1,512]"\--input_shape="attention_mask:[1,512]"\--dynamic_dims="16,32,64,128,256,512"\--log=info常见转换问题与解决方案
问题 1:动态算子不支持
# 错误:ONNX 导出生成了动态输出形状# 现象:ATC 报错 "Input shape not fully specified"# 解决 1:在导出时指定静态形状dummy_input=torch.randint(0,50000,(1,512))# 不要让 shape 变成动态的# 解决 2:使用 opset_version=13+ 并指定动态轴torch.onnx.export(model,dummy_input,"model.onnx",dynamic_axes={"input_ids":{1:"seq_len"}})# 然后在 ATC 中指定 --dynamic_dims问题 2:算子不被支持
# 现象:ATC 报错 "Not supported operator: xxx"# 原因:这个算子在 CANN 中没有实现# 解决 1:替换成 CANN 支持的算子# 比如把 torch.nn.GELU 换成自定义的 GELU 算子# 解决 2:使用 ASCF(Ascend Common Framework)自定义算子# 参考:https://atomgit.com/cann/ascf# 解决 3:分模块转换classModelWithCustomOp(nn.Module):def__init__(self):super().__init__()self.encoder=Encoder()# 能转换的部分self.custom_op=CustomOp()# 不能转换的部分defforward(self,x):x=self.encoder(x)x=self.custom_op(x)# 这部分单独处理returnx# 分别转换能转换的部分问题 3:精度下降
# 现象:转换后模型精度下降# 解决 1:使用混合精度atc--model=model.onnx \--output=model \--framework=5\--soc_version=Ascend910 \--precision_mode=allow_mixed_precision# 解决 2:强制 FP32atc--model=model.onnx \--output=model \--framework=5\--soc_version=Ascend910 \--precision_mode=force_fp16# 解决 3:开启算子级精度配置# 在模型代码中指定某些算子用 FP32classModel(nn.Module):@torch.amp.autocast(device_type='npu',dtype=torch.float32)defforward(self,x):returnself.layer_norm(x)问题 4:内存溢出
# 现象:ATC 转换过程中 OOM# 解决 1:减小 batch size--input_shape="input_ids:[1,512]"# 解决 2:开启模型优化atc--model=model.onnx\--output=model\--framework=5\--soc_version=Ascend910\--buffer_optimize=optimize_for_memory# 解决 3:使用图层融合atc--model=model.onnx\--output=model\--framework=5\--soc_version=Ascend910\--fusion_switch_file=fusion_switch.cfg进阶:自定义算子转换
如果模型中有 CANN 不支持的算子,需要自定义算子然后注册到 ATC。
步骤 1:编写 Ascend C 算子
// custom_gelu.cpp#include"acl/acl.h"extern"C"aclStatusCustomGeluCompute(void*inputs[],void*outputs[]){half*input=(half*)inputs[0];half*output=(half*)outputs[0];int32_tlength=512;// 实际从 shape 获取for(inti=0;i<length;i++){floatx=(float)input[i];floatx3=x*x*x;floatt=tanh(0.7978845608f*(x+0.044715f*x3));output[i]=(half)(0.5f*x*(1.0f+t));}returnACL_SUCCESS;}步骤 2:编译算子
ascendc-ocustom_gelu.o-ccustom_gelu.cpp-targetai_core@ascend910 ld-olibcustom_gelu.so custom_gelu.o -L${ASCEND_TOOLKIT_HOME}/lib -lstdc++-lm步骤 3:注册算子
# 在模型转换时指定自定义算子路径atc--model=model.onnx \--output=model \--framework=5\--soc_version=Ascend910 \--op_select_implmode=high_performance \--optypelist_for_implmode=CustomGelu:CustomGeluProc \--customop_dynamic_batch_strategy=1\--insert_op_conf=custom_op.cfg模型验证
转换完成后,验证模型正确性:
importnumpyasnpimportacl# 初始化 ACLacl.init()device_id=0acl.rt.set_device(device_id)# 加载 OM 模型model_id=acl.mdl.load_from_file("transformer.om")# 准备输入input_data=np.random.randint(0,50000,(1,512)).astype(np.int32)input_buffer=acl.util.numpy_to_vec(input_data)# 执行推理outputs=acl.mdl.execute(model_id,[input_buffer])# 验证输出print(outputs[0].shape)print(outputs[0])完整示例:DeepSeek 模型转换
# deepseek_convert.pyimporttorchfromtransformersimportDeepSeekForCausalLM# 1. 加载 PyTorch 模型print("Loading PyTorch model...")model=DeepSeekForCausalLM.from_pretrained("deepseek-ai/DeepSeek-7B")model.eval()# 2. 导出 ONNXprint("Exporting to ONNX...")dummy_input=torch.randint(0,32000,(1,2048))torch.onnx.export(model,dummy_input,"deepseek7b.onnx",input_names=["input_ids"],output_names=["logits"],dynamic_axes={"input_ids":{0:"batch",1:"seq_len"}},opset_version=14,do_constant_folding=True)print("ONNX export done!")# 3. 转换 OMatc--model=deepseek7b.onnx\--output=deepseek7b\--framework=5\--soc_version=Ascend910\--input_shape="input_ids:[1,2048]"\--dynamic_batch="1,2,4,8"\--precision_mode=allow_mixed_precision\--buffer_optimize=optimize_for_memory\--log=infoecho"OM conversion done! Output: deepseek7b.om"相关资料
- cann-recipes-infer:推理配方,含模型转换示例 → https://atomgit.com/cann/cann-recipes-infer
- cann-samples:算子样例,含自定义算子 → https://atomgit.com/cann/cann-samples
- asc-devkit:Ascend C 开发 → https://atomgit.com/cann/asc-devkit
- cann-learning-hub:学习中心 → https://atomgit.com/cann/cann-learning-hub
