05_yolox_s的后处理截断并导出onnx
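目的是得到去掉后处理(不做 decode、不做 concat)的 ONNX 模型:每个检测尺度直接输出 reg、obj、cls 三个原始张量,共 9 个输出(reg1/obj1/cls1、reg2/obj2/cls2、reg3/obj3/cls3)。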
1、
从官方仓库获取 YOLOX 源码和 yolox_s.pth 权重文件:
https://github.com/Megvii-BaseDetection/YOLOX
2、
修改 yolo_head.py 中检测头的 forward 方法,替换为以下内容(该实现仅用于导出 ONNX,会忽略 labels,训练时请恢复原实现):
def forward(self, xin, labels=None, imgs=None):
    """Run the detection head and return its raw per-level predictions.

    Replacement for the head's ``forward`` that is meant only for ONNX
    export. The decode and concat post-processing is removed, so the
    exported graph ends at the prediction convolutions.

    Args:
        xin: Per-level feature maps. They are consumed in lockstep with
            ``self.cls_convs``, ``self.reg_convs`` and ``self.strides``, so
            iteration stops at the shortest of these. ``self.stems`` and the
            ``*_preds`` modules are indexed by the level number.
        labels: Accepted for signature compatibility; ignored.
        imgs: Accepted for signature compatibility; ignored.

    Returns:
        A flat list with three entries per level, in the order
        ``[reg_0, obj_0, cls_0, reg_1, obj_1, cls_1, ...]``. No activation
        is applied here, so obj and cls are raw logits.
        NOTE(review): the stock YOLOX head applies sigmoid to obj/cls before
        decoding, so downstream post-processing must do that itself.
    """
    outputs = []
    levels = zip(self.cls_convs, self.reg_convs, self.strides, xin)
    # Without decoding the stride itself is unused. It stays in the zip so
    # that the number of levels is still bounded by ``self.strides``.
    for k, (cls_conv, reg_conv, _stride, feat) in enumerate(levels):
        stem = self.stems[k](feat)
        cls_feat = cls_conv(stem)
        reg_feat = reg_conv(stem)
        cls_output = self.cls_preds[k](cls_feat)  # [B, C, H, W]
        reg_output = self.reg_preds[k](reg_feat)  # [B, 4, H, W]
        obj_output = self.obj_preds[k](reg_feat)  # [B, 1, H, W]
        # Key point: no decode and no concat. Every tensor becomes its own
        # ONNX output.
        outputs.extend((reg_output, obj_output, cls_output))
    return outputs
# 3、
修改 tools/export_onnx.py 中的 main() 为以下内容:
def main():
    """Export the YOLOX model with the truncated head to ONNX.

    Parses the command-line options from ``make_parser()`` and builds the
    model described by the experiment file. It then loads the checkpoint and
    exports the model with one ONNX output per raw head tensor: ``reg<i>``,
    ``obj<i>``, ``cls<i>`` for every level ``i`` (1-based). This is the order
    in which the modified head ``forward`` appends them. Unless
    ``args.no_onnxsim`` is set, the exported file is then simplified in
    place with onnx-simplifier.

    NOTE: ``args.output`` is not read; the output names are generated below.

    Raises:
        RuntimeError: if onnx-simplifier cannot validate the simplified model.
    """
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        # No explicit checkpoint: use the experiment's best checkpoint.
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # Load the model state dict; checkpoints may wrap it under a "model" key.
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model = replace_module(model, nn.SiLU, SiLU)
    model.head.decode_in_inference = False
    logger.info("loading checkpoint done.")

    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    # One name per tensor returned by the modified head: for each level the
    # head appends reg, obj, cls in that order. Deriving the level count from
    # the head keeps the names in sync with it; with three strides this gives
    # reg1, obj1, cls1, reg2, obj2, cls2, reg3, obj3, cls3.
    output_names = [
        "{}{}".format(kind, level)
        for level in range(1, len(model.head.strides) + 1)
        for kind in ("reg", "obj", "cls")
    ]

    dynamic_axes = None
    if args.dynamic:
        # Only the batch dimension is dynamic, on the input and every output.
        dynamic_axes = {name: {0: "batch"} for name in [args.input] + output_names}

    # torch.onnx._export is a private helper that may be missing from newer
    # PyTorch releases; fall back to the public torch.onnx.export in that case.
    export_fn = getattr(torch.onnx, "_export", torch.onnx.export)
    export_fn(
        model,
        dummy_input,
        args.output_name,
        input_names=[args.input],
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=args.opset,
    )
    logger.info("generated onnx model named {}".format(args.output_name))

    if not args.no_onnxsim:
        import onnx
        from onnxsim import simplify

        # Use onnx-simplifier to reduce redundant parts of the model.
        onnx_model = onnx.load(args.output_name)
        model_simp, check = simplify(onnx_model)
        # Explicit raise instead of assert so the check survives `python -O`.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(model_simp, args.output_name)
        logger.info("generated simplified onnx model named {}".format(args.output_name))
# 4、
导出命令:
python tools/export_onnx.py -f exps/default/yolox_s.py -c yolox_s.pth --output-name yolox_s.onnx --opset 12
(修改后的 main() 已固定生成 9 个输出名,无需再传 --output;若 --output 不带参数,argparse 会直接报错。)
上述步骤完成后即可得到所需的 ONNX 模型。
