一、ONNX 与 CANN 的关系
1.1 模型流转路径
PyTorch/TensorFlow ↓ (export) ONNX 模型 ↓ (ATC 转换) CANN .om 模型 ↓ (ACL 推理) 昇腾 NPU 执行 ONNX 是中间格式,ATC 是桥梁
1.2 为什么需要了解兼容性
常见痛点: 1. PyTorch 新算子 ONNX 不支持 2. ONNX 支持但 ATC 不支持 3. 算子行为不一致 (数值差异) 4. 动态 shape 处理差异 5. 自定义算子需要手动注册
二、算子映射表
2.1 CANN 支持的 ONNX 算子
# CANN 支持的 ONNX 算子分类SUPPORTED_OPS={# 算术运算'Add':'_supported','Sub':'supported','Mul':'supported','Div':'supported','Pow':'supported','Sqrt':'supported','Exp':'supported','Log':'supported',# 矩阵运算'MatMul':'supported','Gemm':'supported','MatMulInteger':'supported',# 卷积'Conv':'supported','ConvTranspose':'supported',# 池化'MaxPool':'supported','AveragePool':'supported','GlobalAveragePool':'supported','GlobalMaxPool':'supported',# 激活函数'Relu':'supported','LeakyRelu':'supported','PRelu':'supported','Sigmoid':'supported','Tanh':'supported','Softmax':'supported','Gelu':'supported','Selu':'supported','Elu':'supported',# 归一化'BatchNormalization':'supported','LayerNormalization':'supported','InstanceNormalization':'supported','GroupNormalization':'supported',# Reshape'Reshape':'supported','Flatten':'supported','Squeeze':'supported','Unsqueeze':'supported','Transpose':'supported','Concat':'supported','Split':'supported','Gather':'supported','Slice':'supported',# 注意力'Attention':'partial',# 需要特定格式'MultiHeadAttention':'partial',# 不支持的算子'TopK':'not_supported','NonZero':'not_supported','ScatterND':'not_supported','Upsample':'not_supported',}
2.2 算子行为差异
# 关键行为差异BEHAVIOR_DIFFERENCES={'Conv':{'onnx':'支持 auto_pad','cann':'部分支持,需指定 pad 值','workaround':'手动计算 pad 值'},'Reshape':{'onnx':'支持 -1 推断','cann':'支持','workaround':'无'},'Softmax':{'onnx':'axis 参数默认 -1','cann':'axis 需显式指定','workaround':'显式指定 axis=-1'},'BatchNormalization':{'onnx':'支持 spatial 模式','cann':'仅支持 spatial=1','workaround':'确保 spatial=1'},'Slice':{'onnx':'支持动态 ends/axes','cann':'需静态确定','workaround':'使用 ONNX Simplifier'},}
三、常见转换报错排查
3.1 报错分类与解决方案
COMMON_ERRORS={# 错误类型 1: 不支持的算子'E10001':{'message':'Operator xxx is not supported','cause':'ATC 不支持该 ONNX 算子','solutions':['检查是否有等价的替代算子','使用 ONNX Simplifier 简化模型','自定义算子实现','使用 PyTorch 重新导出(更换算子)']},# 错误类型 2: 算子属性不支持'E10002':{'message':'Attribute xxx of operator xxx is not supported','cause':'算子属性 ATC 不支持','solutions':['修改模型使用支持的属性值','拆分算子为多个支持的算子组合','使用自定义算子']},# 错误类型 3: Shape 不兼容'E10003':{'message':'Shape inference failed for operator xxx','cause':'Shape 推断失败','solutions':['使用 ONNX Simplifier 固定 shape','检查动态 shape 配置','使用 input_shape_range']},# 错误类型 4: 数据类型不支持'E10004':{'message':'Data type xxx is not supported','cause':'数据类型不支持','solutions':['转换为 FP32/FP16/INT8','检查导出时的 dtype 设置','使用 onnxconverter-common 转换']},# 错误类型 5: 内存不足'E10005':{'message':'Memory allocation failed','cause':'转换时内存不足','solutions':['减小 batch size','使用量化','简化模型结构']},}
3.2 自动排查工具
importonnxfromonnximportshape_inferenceclassONNXCompatibilityChecker:def__init__(self,model_path):self.model=onnx.load(model_path)self.model=shape_inference.infer_shapes(self.model)defcheck(self):"""检查 ONNX 模型兼容性"""issues=[]# 1. 检查不支持的算子fornodeinself.model.graph.node:op_type=node.op_typeifop_typenotinSUPPORTED_OPS:issues.append({'type':'unsupported_op','node':node.name,'op':op_type,'severity':'error'})# 2. 检查数据类型forinputinself.model.graph.input:dtype=input.type.tensor_type.elem_typeifdtypenotin[1,7,10]:# FP32, FP16, INT8issues.append({'type':'unsupported_dtype','name':input.name,'dtype':dtype,'severity':'warning'})# 3. 检查动态 shapeforinputinself.model.graph.input:shape=input.type.tensor_type.shapefordiminshape.dim:ifdim.HasField('dim_param'):issues.append({'type':'dynamic_shape','name':input.name,'dim':dim.dim_param,'severity':'info'})returnissuesdefreport(self):"""生成检查报告"""issues=self.check()errors=[iforiinissuesifi['severity']=='error']warnings=[iforiinissuesifi['severity']=='warning']infos=[iforiinissuesifi['severity']=='info']print(f"检查完成:{len(errors)}个错误,{len(warnings)}个警告,{len(infos)}个提示")forissueinerrors:print(f" ❌{issue['type']}:{issue.get('node',issue.get('name'))}-{issue.get('op',issue.get('dtype',''))}")forissueinwarnings:print(f" ⚠️{issue['type']}:{issue.get('name')}")forissueininfos:print(f" ℹ️{issue['type']}:{issue.get('name')}-{issue.get('dim','')}")returnlen(errors)==0# 使用示例checker=ONNXCompatibilityChecker("model.onnx")is_compatible=checker.report()
四、ONNX 模型简化
4.1 使用 ONNX Simplifier
importonnxfromonnxsimimportsimplifydefsimplify_onnx(input_path,output_path):"""简化 ONNX 模型"""model=onnx.load(input_path)# 简化模型model_simp,check=simplify(model,dynamic_input_shape=True,input_shapes={'input':[1,3,224,224]})assertcheck,"简化后的模型验证失败"onnx.save(model_simp,output_path)print(f"模型已简化:{input_path}→{output_path}")# 使用示例simplify_onnx("model.onnx","model_simplified.onnx")
4.2 ONNX 优化
importonnxfromonnximportoptimizerdefoptimize_onnx(input_path,output_path):"""优化 ONNX 模型"""model=onnx.load(input_path)# 优化 pass 列表passes=['eliminate_identity','eliminate_nop_transpose','fuse_consecutive_transposes','fuse_bn_into_conv','fuse_add_bias_into_conv','fuse_matmul_add_bias_into_gemm',]optimized_model=optimizer.optimize(model,passes)onnx.save(optimized_model,output_path)print(f"模型已优化:{output_path}")optimize_onnx("model.onnx","model_optimized.onnx")
五、自定义算子注册
5.1 ATC 自定义算子
// custom_op.cpp#include"register/op_impl_registry.h"// 算子注册classCustomOp:publicops::OpDef{public:CustomOp():ops::OpDef("CustomOp"){}// 输入定义voidInputs(conststd::vector<ge::TensorDesc>&inputs)override{// 定义输入 tensor}// 输出定义voidOutputs(conststd::vector<ge::TensorDesc>&outputs)override{// 定义输出 tensor}// 属性定义voidAttr(conststd::string&name,constge::AnyValue&value)override{// 定义算子属性}// 计算实现ge::StatusCompute(ge::op::ComputeContext&context,conststd::vector<ge::Tensor*>&inputs,std::vector<ge::Tensor*>&outputs)override{// 获取输入ge::Tensor*input=inputs[0];ge::Tensor*output=outputs[0];// 计算逻辑// ...returnge::GRAPH_SUCCESS;}};// 注册算子REGISTER_OP(CustomOp).INPUT(0,ge::DT_FLOAT).OUTPUT(0,ge::DT_FLOAT).ATTR("alpha",ge::ATTR_TYPE_FLOAT,1.0);
5.2 ONNX 自定义算子映射
importonnxdefregister_custom_op_mapping(onnx_op_name,cann_op_name):"""注册 ONNX 算子到 CANN 算子的映射"""# 在 ATC 转换时使用cmd=f""" atc --model=model.onnx \ --framework=5 \ --output=model \ --input_shape="input:1,3,224,224" \ --soc_version=Ascend310 \ --op_mapping="{onnx_op_name}:{cann_op_name}" """returncmd# 使用示例# 将 ONNX 的 CustomRelu 映射到 CANN 的 Relucmd=register_custom_op_mapping("CustomRelu","Relu")
六、动态算子处理
6.1 动态 Shape ONNX 转换
defconvert_dynamic_onnx(model_path,output_path,input_shapes):"""转换动态 ONNX 模型"""importsubprocess# 构建 input_shape_rangeinput_shape_ranges=[]forname,(min_shape,opt_shape,max_shape)ininput_shapes.items():min_str=','.join(map(str,min_shape))opt_str=','.join(map(str,opt_shape))max_str=','.join(map(str,max_shape))input_shape_ranges.append(f"{name}:{min_str}~{opt_str}~{max_str}")# 构建命令cmd=["atc","--model",model_path,"--framework","5","--output",output_path,"--soc_version","Ascend310","--input_shape_range",';'.join(input_shape_ranges)]result=subprocess.run(cmd,capture_output=True,text=True)ifresult.returncode!=0:print(f"转换失败:{result.stderr}")else:print(f"转换成功:{output_path}")# 使用示例convert_dynamic_onnx("model.onnx","model_dynamic",{"input_ids":([1,1],[1,128],[32,512]),"attention_mask":([1,1],[1,128],[32,512])})
七、常见问题速查表
| 报错关键词 | 原因 | 快速解决 |
|---|
not supported | 算子不支持 | ONNX Simplifier 或更换算子 |
attribute not supported | 属性不支持 | 修改模型使用支持的属性 |
shape inference failed | Shape 推断失败 | 固定 shape 或用 input_shape_range |
data type not supported | 数据类型不支持 | 转换为 FP32/FP16 |
memory allocation failed | 内存不足 | 减小 batch 或量化 |
input output mismatch | 输入输出不匹配 | 检查 input_shape 配置 |
graph verify failed | 图验证失败 | 用 Netron 检查模型结构 |
op type not registered | 算子未注册 | 自定义算子或替换 |
相关仓库
- onnx- ONNX 格式规范 https://gitee.com/onnx/onnx
- onnxsim- ONNX 简化工具 https://github.com/onnx/onnx-simplifier
- onnxruntime- ONNX 推理运行时 https://github.com/microsoft/onnxruntime
- onnxconverter-common- ONNX 转换通用工具 https://github.com/microsoft/onnxconverter-common
- torch.onnx- PyTorch ONNX 导出 https://gitee.com/pytorch/pytorch
- tf2onnx- TensorFlow 转 ONNX https://github.com/onnx/tensorflow-onnx
- atc- ATC 转换工具 https://gitee.com/ascend/atc
- ascend-cl- ACL 接口 https://gitee.com/ascend/ascend-cl
- Netron- 模型可视化 https://netron.app