告别命令行:用Python脚本一键调用trtexec,批量转换ONNX到TensorRT Engine
告别命令行:用Python脚本一键调用trtexec,批量转换ONNX到TensorRT Engine
在AI模型部署的日常工作中,模型格式转换往往是耗时又容易出错的环节。每次打开终端,输入一长串trtexec命令参数,稍有不慎就会因拼写错误或参数遗漏导致转换失败。更不用说当需要批量处理数十个ONNX模型时,重复劳动带来的效率低下和人为错误几乎不可避免。
想象一下这样的场景:你刚完成一组YOLOv7模型的训练和导出,现在需要将它们全部转换为TensorRT引擎文件。每个模型需要测试fp16和int8两种精度,还要针对不同batch size生成多个版本。如果手动操作,不仅需要记住各种参数组合,还要处理路径问题、日志记录和错误重试——这简直是一场噩梦。
1. 为什么需要自动化转换工具
TensorRT作为NVIDIA推出的高性能推理引擎,能显著提升模型在NVIDIA GPU上的运行效率。而trtexec则是TensorRT工具包中用于模型转换的瑞士军刀,支持从ONNX到TensorRT引擎的转换。但它的命令行操作方式存在几个明显痛点:
- 参数复杂:常用的
--fp16、--workspace等参数容易遗漏或写错 - 批量处理困难:需要手动为每个模型编写转换命令
- 缺乏错误处理:转换失败时往往需要重新开始
- 日志记录缺失:难以追溯哪些模型转换成功或失败
# 典型trtexec命令示例(手动输入易错) trtexec --onnx=model.onnx --saveEngine=model.trt --fp16 --workspace=4096通过Python脚本封装trtexec,我们可以实现:
- 一键批量转换:自动遍历文件夹中的所有ONNX文件
- 参数集中管理:在配置文件中定义不同精度和batch size组合
- 智能错误处理:自动重试失败的转换任务
- 完整日志记录:记录每个模型的转换状态和耗时
2. 环境准备与trtexec配置
2.1 基础环境检查
在开始编写脚本前,确保你的Windows系统已安装以下组件:
| 组件 | 推荐版本 | 验证命令 |
|---|---|---|
| CUDA | 11.x | nvcc --version |
| cuDNN | 对应CUDA版本 | 检查cudnn64_*.dll |
| TensorRT | 8.x | 检查trtexec.exe |
| Python | 3.8+ | python --version |
提示:建议使用conda创建独立Python环境,避免依赖冲突
2.2 定位trtexec可执行文件
trtexec通常位于TensorRT安装目录的bin文件夹中。我们的脚本需要自动发现其位置:
import os from pathlib import Path def find_trtexec(): # 常见安装路径 search_paths = [ "C:/Program Files/NVIDIA/TensorRT/bin", os.environ.get("TENSORRT_PATH", ""), "C:/TensorRT/bin" ] for path in search_paths: if path and (trtexec := Path(path) / "trtexec.exe").exists(): return str(trtexec) raise FileNotFoundError("trtexec.exe not found in standard locations")3. 核心脚本设计与实现
3.1 配置文件设计
使用YAML文件管理转换参数,支持不同模型的不同配置:
# config.yaml示例 global_params: workspace: 4096 min_shapes: "1,3,224,224" opt_shapes: "8,3,224,224" max_shapes: "32,3,224,224" models: - name: "yolov7" onnx_path: "models/yolov7.onnx" output_dir: "engines/" precisions: ["fp32", "fp16"] batch_sizes: [1, 4, 8] - name: "resnet50" onnx_path: "models/resnet50.onnx" tactics: "-cublasLt,+cublas"3.2 批量转换核心逻辑
import subprocess import yaml from datetime import datetime def convert_onnx_to_trt(config_path): with open(config_path) as f: config = yaml.safe_load(f) trtexec_path = find_trtexec() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_file = open(f"conversion_{timestamp}.log", "w") for model in config["models"]: for precision in model.get("precisions", ["fp32"]): for batch in model.get("batch_sizes", [1]): output_name = f"{model['name']}_bs{batch}_{precision}.trt" output_path = Path(model["output_dir"]) / output_name cmd = [ trtexec_path, f"--onnx={model['onnx_path']}", f"--saveEngine={output_path}", f"--workspace={config['global_params']['workspace']}", f"--minShapes=input:{config['global_params']['min_shapes']}", f"--optShapes=input:{config['global_params']['opt_shapes']}", f"--maxShapes=input:{config['global_params']['max_shapes']}", f"--{precision}", f"--batch={batch}" ] if "tactics" in model: cmd.append(f"--tacticSources={model['tactics']}") try: result = subprocess.run( cmd, check=True, capture_output=True, text=True ) log_file.write(f"SUCCESS: {output_name}\n") except subprocess.CalledProcessError as e: log_file.write(f"FAILED: {output_name}\nError: {e.stderr}\n") log_file.close()4. 高级功能扩展
4.1 动态形状支持
对于需要动态batch size的模型,可以通过形状参数灵活控制:
# 在配置文件中定义形状范围 dynamic_shapes: input1: min: "1,3,224,224" opt: "8,3,224,224" max: "32,3,224,224" input2: min: "1,1,100" opt: "8,1,100" max: "32,1,100"4.2 性能分析与优化
转换完成后自动运行基准测试:
def benchmark_engine(engine_path): cmd = [ trtexec_path, f"--loadEngine={engine_path}", "--duration=10", "--useSpinWait" ] result = subprocess.run(cmd, capture_output=True, text=True) # 解析输出获取吞吐量和延迟数据 throughput = parse_throughput(result.stdout) latency = parse_latency(result.stdout) return { "throughput": throughput, "latency": latency }5. 错误处理与调试技巧
5.1 常见错误解决方案
| 错误类型 | 表现 | 解决方案 |
|---|---|---|
| CUBLAS错误 | CublasLtWrapper::setupHeuristic | 添加--tacticSources=-cublasLt,+cublas |
| 内存不足 | out of memory | 减小--workspace值 |
| 形状不匹配 | input dimensions mismatch | 检查ONNX模型的输入形状 |
5.2 日志分析工具
编写日志解析脚本,快速统计转换成功率:
def analyze_logs(log_file): with open(log_file) as f: lines = f.readlines() success = sum(1 for line in lines if "SUCCESS" in line) failed = sum(1 for line in lines if "FAILED" in line) print(f"转换成功率: {success/(success+failed):.1%}") print("失败模型列表:") print("\n".join(line.split(":")[1] for line in lines if "FAILED" in line))在实际项目中,这套自动化转换工具将模型部署的准备时间从数小时缩短到几分钟。特别是在需要频繁迭代模型版本的场景下,只需更新配置文件即可触发批量转换,彻底告别了手动输入命令的低效工作方式。
