从标注到上线:手把手教你用HRNet(OCR分支)训练自己的语义分割模型(附TensorRT加速与Triton部署全流程)
从标注到上线:HRNet-OCR语义分割全流程实战指南
在工业质检、遥感影像分析和自动驾驶等场景中,像素级语义分割技术正成为关键基础设施。HRNet(High-Resolution Network)凭借其独特的并行多分辨率特征融合架构,在保持高分辨率特征的同时实现高效计算,特别适合需要精细边界的应用场景。结合OCR(Object-Contextual Representations)模块后,模型能够更好地理解对象间的上下文关系,进一步提升分割精度。
本文将采用"道路裂缝检测"作为示例场景,完整演示从数据标注到生产环境部署的全链路流程。不同于常规教程仅展示基础操作,我们会重点剖析每个环节的工程决策要点,包括数据增强策略选择、类别不平衡处理、TensorRT优化技巧以及Triton推理服务器的性能调优方法。
1. 数据准备与标注工程
1.1 标注工具选型与技巧
Labelme作为开源标注工具虽然简单易用,但在实际工业场景中需要考虑更多细节:
# 安装带多边形编辑增强版的labelme pip install labelme -i https://pypi.tuna.tsinghua.edu.cn/simple标注效率提升技巧:
- 使用快捷键(W创建多边形,A/D切换图片)
- 对相似物体采用"复制标注"功能
- 设置
label_config.json预定义类别颜色
注意:标注时应保持约10%的重叠区域,避免后续数据增强时出现空白边缘
1.2 数据集格式转换实战
Cityscapes格式虽通用,但原始实现需要调整以适应不同场景:
from PIL import Image import numpy as np def convert_label(label_path): """将彩色标签图转换为索引图""" color_map = { (0,0,0): 0, # 背景 (255,0,0): 1, # 裂缝 (0,255,0): 2 # 修补区域 } label = np.array(Image.open(label_path)) index_map = np.zeros(label.shape[:2], dtype=np.uint8) for color, index in color_map.items(): index_map[(label == color).all(axis=-1)] = index return Image.fromarray(index_map)常见问题解决方案:
- 遇到内存不足时,改用生成器逐图处理
- 大尺寸图像建议先resize再标注
- 使用
tqdm库添加进度条监控转换过程
2. HRNet-OCR模型训练精要
2.1 环境配置避坑指南
PyTorch环境配置需特别注意CUDA兼容性:
| 组件 | 推荐版本 | 替代方案 | 注意事项 |
|---|---|---|---|
| CUDA | 11.1 | 10.2 | 需与驱动版本匹配 |
| cuDNN | 8.0.5 | 7.6.5 | 需从NVIDIA官网下载 |
| PyTorch | 1.9.0 | 1.7.1 | 使用conda安装更稳定 |
| TorchVision | 0.10.0 | 0.8.2 | 需与PyTorch版本对应 |
典型问题排查:
# 验证CUDA可用性 python -c "import torch; print(torch.cuda.is_available())" # 检查cuDNN python -c "import torch; print(torch.backends.cudnn.version())"2.2 关键配置参数解析
修改seg_hrnet_ocr_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml时需关注:
TRAIN: IMAGE_SIZE: [512, 512] # 根据显存调整 BASE_SIZE: 2048 # 原始图像短边尺寸 BATCH_SIZE_PER_GPU: 4 # 1080Ti建议设为2-4 CLASS_WEIGHTS: [1.0, 2.0, 1.5] # 类别权重系数 OPTIMIZER: LR: 0.01 # 初始学习率 WD: 0.0005 # 权重衰减 MOMENTUM: 0.9训练监控技巧:
# 启动TensorBoard监控 tensorboard --logdir=output --port=6006 # 常用监控指标 watch -n 0.5 nvidia-smi # GPU利用率监控3. TensorRT加速实战
3.1 模型转换全流程
使用Docker环境保证一致性:
# 基于官方镜像构建定制环境 FROM nvcr.io/nvidia/tensorrt:21.03-py3 RUN apt-get update && apt-get install -y \ libgl1 libgtk2.0-dev \ && rm -rf /var/lib/apt/lists/* WORKDIR /workspace转换关键步骤:
- 生成.wts中间文件:
python tools/gen_wts.py \ --cfg experiments/road_crack.yaml \ --ckpt output/best.pth \ --save_path hrnet_ocr.wts- 编译TensorRT引擎:
./hrnet_ocr -s hrnet_ocr.wts hrnet_ocr.engine 48 # 48表示使用HRNet-W483.2 性能优化技巧
通过trtexec工具进行基准测试:
trtexec --loadEngine=hrnet_ocr.engine \ --shapes=input:1x512x512x3 \ --fp16 \ --verbose优化参数对比:
| 模式 | 延迟(ms) | 显存占用(MB) | 适用场景 |
|---|---|---|---|
| FP32 | 45.2 | 1243 | 高精度要求 |
| FP16 | 28.7 | 867 | 平衡精度与速度 |
| INT8 | 19.4 | 512 | 极致性能需求 |
4. Triton推理服务部署
4.1 服务端配置详解
config.pbtxt关键参数说明:
instance_group { count: 2 # 实例数 kind: KIND_GPU gpus: [0, 1] # 多卡部署 } dynamic_batching { max_queue_delay_microseconds: 1000 preferred_batch_size: [1, 4, 8] } model_warmup { { name: "warmup_sample" batch_size: 1 inputs: { key: "data" value: { data_type: TYPE_FP32 dims: [512, 512, 3] zero_data: true } } } }启动参数优化:
docker run -d --gpus all \ --shm-size=16G \ -p 8000-8002:8000-8002 \ -v /path/to/models:/models \ nvcr.io/nvidia/tritonserver:21.03-py3 \ tritonserver --model-repository=/models \ --http-thread-count=8 \ --grpc-infer-allocation-pool-size=324.2 客户端最佳实践
带批处理的异步客户端实现:
import tritonclient.grpc.aio as grpcclient class TritonInferencer: def __init__(self, url): self.client = grpcclient.InferenceServerClient(url) async def infer_batch(self, image_batch): inputs = [grpcclient.InferInput("data", image_batch.shape, "FP32")] inputs[0].set_data_from_numpy(image_batch) outputs = [grpcclient.InferRequestedOutput("output")] response = await self.client.async_infer( model_name="hrnet_ocr", inputs=inputs, outputs=outputs ) return response.as_numpy("output") # 使用示例 async def process_video(video_path): inferencer = TritonInferencer("localhost:8001") batch = preprocess_frames(frames) # 预处理帧数据 results = await inferencer.infer_batch(batch) postprocess(results)性能监控指标:
- 使用Prometheus收集
nv_gpu_utilization - 通过Triton的
/metrics端点获取QPS - 使用
perf_analyzer进行负载测试
在实际部署道路裂缝检测系统时,我们发现将预处理(resize/normalize)移到客户端可减少约30%的服务器负载。对于1080p视频流,采用FP16模式的Triton实例在T4 GPU上可实现45FPS的实时处理能力。
