当前位置: 首页 > news >正文

Windows10下YOLOv8-Pose(8.2.10)从零部署:自定义数据集训练与工程化推理实战

1. 环境准备与工具安装

在Windows10系统下部署YOLOv8-Pose需要先搭建好开发环境。我建议使用Anaconda来管理Python环境,这样可以避免不同项目之间的依赖冲突。首先下载并安装Anaconda最新版,这个步骤很简单,就像安装普通软件一样一路下一步即可。

装好Anaconda后,打开命令提示符(CMD)或Anaconda Prompt,创建一个新的Python虚拟环境:

conda create -n yolo8 python=3.8 conda activate yolo8

这里我选择Python 3.8是因为它在兼容性方面表现最稳定。激活环境后,我们需要安装PyTorch。这里有个坑要注意:必须安装与你的CUDA版本匹配的PyTorch。你可以通过运行nvidia-smi命令查看CUDA版本。以CUDA 11.6为例:

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

接下来安装其他必要的依赖包:

pip install numpy opencv-python pillow pandas matplotlib seaborn tqdm wandb seedir -i https://pypi.tuna.tsinghua.edu.cn/simple

最后安装Ultralytics官方库,这是YOLOv8的核心:

pip install ultralytics

2. 数据标注与预处理

2.1 使用LabelMe标注关键点

对于自定义关键点检测任务,我推荐使用LabelMe这个开源工具。它支持矩形框和关键点标注,而且生成的JSON格式很容易转换。安装LabelMe很简单:

pip install labelme labelme

标注时有个小技巧:先标注完所有图片的边界框,然后再统一标注某一类关键点。比如你要检测工业零件的三个角度关键点,可以先把所有图片的30度角点标完,再标60度的,最后标90度的。这样做有两个好处:一是效率高,二是减少标注错误。

标注完成后,每个图片会生成对应的JSON文件。我强烈建议写个脚本检查标注质量:

import os import cv2 import json def visualize_annotations(image_folder): for file in os.listdir(image_folder): if file.endswith('.jpg'): img_path = os.path.join(image_folder, file) json_path = os.path.splitext(img_path)[0] + '.json' img = cv2.imread(img_path) with open(json_path) as f: data = json.load(f) # 绘制标注框和关键点 for shape in data['shapes']: if shape['shape_type'] == 'rectangle': points = shape['points'] cv2.rectangle(img, (int(points[0][0]), int(points[0][1])), (int(points[1][0]), int(points[1][1])), (0,255,0), 2) elif shape['shape_type'] == 'point': point = shape['points'][0] cv2.circle(img, (int(point[0]), int(point[1])), 5, (0,0,255), -1) cv2.imshow('Annotation Check', img) if cv2.waitKey(0) == ord('q'): break

2.2 数据集划分与格式转换

标注完成后,我们需要将数据集划分为训练集和验证集,通常采用8:2的比例。同时要把LabelMe的JSON格式转换为YOLO格式的TXT文件。这里我分享一个完整的处理流程:

import os import json import random from tqdm import tqdm def convert_labelme_to_yolo(json_folder, output_folder, class_map, kpt_classes): os.makedirs(output_folder, exist_ok=True) for json_file in tqdm(os.listdir(json_folder)): if not json_file.endswith('.json'): continue with open(os.path.join(json_folder, json_file)) as f: data = json.load(f) txt_lines = [] img_width = data['imageWidth'] img_height = data['imageHeight'] # 处理每个标注对象 for shape in data['shapes']: if shape['shape_type'] == 'rectangle': # 转换边界框 points = shape['points'] x_center = (points[0][0] + points[1][0]) / 2 / img_width y_center = (points[0][1] + points[1][1]) / 2 / img_height width = abs(points[1][0] - points[0][0]) / img_width height = abs(points[1][1] - points[0][1]) / img_height line = f"{class_map[shape['label']]} {x_center:.5f} {y_center:.5f} {width:.5f} {height:.5f}" # 收集该框内的关键点 kpts = {} for kpt_shape in data['shapes']: if kpt_shape['shape_type'] == 'point' and \ points[0][0] < kpt_shape['points'][0][0] < points[1][0] and \ points[0][1] < kpt_shape['points'][0][1] < points[1][1]: kpts[kpt_shape['label']] = kpt_shape['points'][0] # 按预定顺序添加关键点 for cls in kpt_classes: if cls in kpts: x = kpts[cls][0] / img_width y = kpts[cls][1] / img_height line += f" {x:.5f} {y:.5f} 2" # 2表示可见 else: line += " 0 0 0" # 0表示不存在 txt_lines.append(line) # 保存YOLO格式标签 txt_filename = os.path.splitext(json_file)[0] + '.txt' with open(os.path.join(output_folder, txt_filename), 'w') as f: f.write('\n'.join(txt_lines))

3. 模型训练与调优

3.1 准备配置文件

YOLOv8-Pose需要一个YAML配置文件来定义数据集和模型参数。创建一个custom_pose.yaml文件:

# 数据集路径 path: ./datasets/custom_pose train: images/train val: images/val # 关键点类别 kpt_shape: [3] # 关键点数量 flip_idx: [] # 对称关键点索引,如[1,0]表示第0和第1个关键点对称 # 类别名称 names: 0: object # 关键点名称和连接关系 skeleton: []

3.2 启动训练

训练命令有很多可调参数,这里我分享几个实用的组合:

# 基础训练(使用预训练权重) yolo pose train data=custom_pose.yaml model=yolov8m-pose.pt epochs=100 imgsz=640 batch=16 device=0 # 高级训练(自定义参数) yolo pose train data=custom_pose.yaml model=yolov8n-pose.pt pretrained=True \ epochs=150 batch=8 imgsz=640 optimizer=Adam lr0=0.001 \ pose=12.0 kobj=1.5 cls=0.5 box=7.5 device=0 workers=4

训练过程中有几个关键点需要注意:

  1. 学习率:太大容易震荡,太小收敛慢。可以从0.01开始尝试。
  2. 批量大小:受限于显存,在保证不OOM的情况下尽可能大。
  3. 损失权重:pose控制关键点定位,kobj控制关键点置信度。
  4. 数据增强:默认开启马赛克增强,对小目标很有效。

3.3 训练监控

推荐使用Weights & Biases(WandB)来监控训练过程:

pip install wandb wandb login

训练时会自动记录各项指标,你可以在网页上实时查看损失曲线、验证精度等。

4. 模型部署与推理

4.1 图片推理

训练完成后,可以使用最佳模型进行推理。这里我封装了一个更易用的推理类:

from ultralytics import YOLO import cv2 import numpy as np class PoseDetector: def __init__(self, model_path, kpt_colors=None, kpt_radius=5, line_thickness=2): self.model = YOLO(model_path) self.kpt_colors = kpt_colors or [(255,0,0), (0,255,0), (0,0,255)] self.kpt_radius = kpt_radius self.line_thickness = line_thickness def detect(self, img, conf_threshold=0.5): results = self.model(img, conf=conf_threshold) visualized_img = img.copy() for result in results: boxes = result.boxes.data.tolist() keypoints = result.keypoints.data.cpu().numpy() for box, kpts in zip(boxes, keypoints): # 绘制边界框 x1, y1, x2, y2 = map(int, box[:4]) cv2.rectangle(visualized_img, (x1,y1), (x2,y2), (0,255,0), 2) # 绘制关键点 for i, (x, y, conf) in enumerate(kpts): if conf > conf_threshold: color = self.kpt_colors[i % len(self.kpt_colors)] cv2.circle(visualized_img, (int(x), int(y)), self.kpt_radius, color, -1) return visualized_img, results # 使用示例 detector = PoseDetector('best.pt') img = cv2.imread('test.jpg') result_img, results = detector.detect(img) cv2.imwrite('result.jpg', result_img)

4.2 视频流处理

对于视频或摄像头实时处理,可以使用以下优化后的代码:

import cv2 from pose_detector import PoseDetector # 上面的类 def process_video(input_path, output_path, model_path): cap = cv2.VideoCapture(input_path) fps = cap.get(cv2.CAP_PROP_FPS) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) detector = PoseDetector(model_path) while cap.isOpened(): ret, frame = cap.read() if not ret: break result_frame, _ = detector.detect(frame) out.write(result_frame) cv2.imshow('Preview', result_frame) if cv2.waitKey(1) == ord('q'): break cap.release() out.release() cv2.destroyAllWindows() # 处理视频文件 process_video('input.mp4', 'output.mp4', 'best.pt') # 处理摄像头 def process_camera(camera_id=0, model_path='best.pt'): cap = cv2.VideoCapture(camera_id) detector = PoseDetector(model_path) while True: ret, frame = cap.read() if not ret: break result_frame, _ = detector.detect(frame) cv2.imshow('Camera', result_frame) if cv2.waitKey(1) == ord('q'): break cap.release() cv2.destroyAllWindows()

5. 工程化优化建议

在实际项目中部署YOLOv8-Pose时,有几个优化方向值得考虑:

  1. 模型量化:使用TensorRT或ONNX Runtime加速推理
yolo export model=best.pt format=onnx opset=12 simplify=True
  1. 多线程处理:使用生产者-消费者模式处理视频流
from queue import Queue from threading import Thread class VideoProcessor: def __init__(self, model_path, buffer_size=10): self.detector = PoseDetector(model_path) self.frame_queue = Queue(maxsize=buffer_size) self.result_queue = Queue(maxsize=buffer_size) def start(self, input_path): self.capture_thread = Thread(target=self._capture_frames, args=(input_path,)) self.process_thread = Thread(target=self._process_frames) self.capture_thread.start() self.process_thread.start() def _capture_frames(self, input_path): cap = cv2.VideoCapture(input_path) while cap.isOpened(): ret, frame = cap.read() if not ret: break self.frame_queue.put(frame) cap.release() def _process_frames(self): while True: frame = self.frame_queue.get() if frame is None: break result = self.detector.detect(frame) self.result_queue.put(result)
  1. 性能监控:添加FPS计数和资源监控
import time import psutil class PerformanceMonitor: def __init__(self): self.start_time = time.time() self.frame_count = 0 self.fps_history = [] def update(self): self.frame_count += 1 if self.frame_count % 10 == 0: elapsed = time.time() - self.start_time fps = self.frame_count / elapsed self.fps_history.append(fps) # 获取内存和CPU使用情况 mem = psutil.virtual_memory().percent cpu = psutil.cpu_percent() print(f"FPS: {fps:.1f} | Memory: {mem}% | CPU: {cpu}%")
  1. 结果后处理:添加关键点滤波和平滑
from collections import deque class KeypointSmoother: def __init__(self, window_size=5): self.window_size = window_size self.history = {} def smooth(self, current_kpts): smoothed = {} for obj_id, kpts in current_kpts.items(): if obj_id not in self.history: self.history[obj_id] = deque(maxlen=self.window_size) self.history[obj_id].append(kpts) # 使用移动平均平滑关键点 smoothed[obj_id] = np.mean(self.history[obj_id], axis=0) return smoothed
http://www.jsqmd.com/news/595005/

相关文章:

  • 3D点云检测实战-Nuscenes数据集解析与Python工具链深度指南
  • Unity HDRP水系统性能避坑指南:从脚本交互到水下渲染,让你的游戏帧率稳如泰山
  • JVM学习-基础篇-垃圾回收
  • OpenClaw浏览器自动化:Qwen3-14B驱动无头爬虫实战
  • 从零开始用JavaScript Canvas画彩虹:理解arc()绘图与颜色渐变
  • HTB——Oopsie
  • Java SpringBoot+Vue3+MyBatis Web在线考试系统系统源码|前后端分离+MySQL数据库
  • 我的CSDN第一篇
  • OpenClaw+千问3.5-35B-A3B-FP8:自动化商品描述生成器
  • TimeGPT新手必看:5分钟搞定token获取与AirPassengers数据集预测实战
  • OpenClaw性能优化:Qwen3-14B镜像的并发请求控制策略
  • Unity2018中SpriteAtlas与AB包的高效集成实践
  • c++如何利用C++23的std--expected重构文件操作的错误管理代码【实战】
  • 自动化数据清洗:OpenClaw调用千问3.5-9B处理混乱CSV文件
  • STM32F103C8T6 RAM不够用?手把手教你用CAN总线实现边收边写的IAP升级(附完整代码)
  • Unity游戏开发:Highlight Plus 8.0在URP渲染管线下的完整配置指南(含常见问题解决)
  • OpenClaw离线模式探索:Qwen3-14b_int4_awq断网环境下的应急方案
  • OpenClaw日志分析自动化:Qwen3-14b_int4_awq模型驱动的问题排查
  • SEO 对于SaaS产品销售有什么影响
  • 电商运营自动化:OpenClaw驱动千问3.5-27B批量生成商品描述
  • TFT_eSPI_Charts嵌入式图表库:轻量级实时可视化方案
  • Agent、Copilot、Advisor
  • 从无人机抗风到机械臂消振:聊聊ESO(扩张状态观测器)在机器人里的那些实战用法
  • 2026年比较好的易打理进口地板/抗菌进口地板稳定供货厂家推荐 - 品牌宣传支持者
  • OpenClaw高阶用法:Qwen3-14B模型的热切换与A/B测试
  • OpenClaw多模型切换指南:百川2-13B-4bits与Qwen3-32B混合调用
  • 基于SpringBoot + Vue的医院患者就诊数据可视化分析系统(角色:患者、医生、管理员)
  • OpenClaw智能旅行规划:千问3.5-35B-A3B-FP8解析景点照片生成个性化行程表
  • OpenClaw浏览器自动化:Qwen3-4B驱动网页检索与内容抓取
  • SQL复杂报表如何通过窗口函数优化_减少子查询提升性能