当前位置: 首页 > news >正文

深度学习中评估指标计算库TorchMetrics的使用

TorchMetrics是一个包含100多个PyTorch指标实现的集合(如分类、检测、分割、回归等),并提供易于使用的API来创建自定义指标。可以将TorchMetrics与任何PyTorch模型或PyTorch Lightning结合使用。源码地址:https://github.com/Lightning-AI/torchmetrics,最新发布版本为v1.9.0,license为Apache-2.0。

安装完YOLO环境后,根据要评估的任务执行相应的安装命令(评估任务不同,所需的安装命令也不同):

pip install torchmetrics
pip install "torchmetrics[detection]"

YOLOv8/YOLO11/YOLO26有自己内置的评估逻辑,一般不建议直接在训练循环内部强行替换其指标计算。这里是在YOLOv8训练完成之后,使用TorchMetrics对模型单独进行评估,仅用于演示TorchMetrics的使用。

Classify主要测试代码如下:

def _parse_label_file(label_file): idx_to_class = {} class_to_idx = {} with open(label_file, mode="r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue idx, name = line.split() idx = int(idx) idx_to_class[idx] = name class_to_idx[name] = idx return idx_to_class, class_to_idx def _get_images(images_path): image_files = list(Path(images_path).rglob("*.*")) image_files = [p for p in image_files if p.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp", ".webp"]] if len(image_files) == 0: raise RuntimeError(colorama.Fore.RED + f"no images found: {images_path}") return image_files def test_classify(model_name, images_path, label_file): if model_name is None or not model_name or not Path(model_name).is_file(): raise FileNotFoundError(colorama.Fore.RED + f"{model_name} is not a file") if images_path is None or not images_path or not Path(images_path).is_dir(): raise FileNotFoundError(colorama.Fore.RED + f"{images_path} is not a directory") if label_file is None or not label_file or not Path(label_file).is_file(): raise FileNotFoundError(colorama.Fore.RED + f"{label_file} is not a file") _, class_to_idx = _parse_label_file(label_file) print(f"class to idx: {class_to_idx}") num_classes = len(class_to_idx) acc_metric = MulticlassAccuracy(num_classes=num_classes) f1_metric = MulticlassF1Score(num_classes=num_classes) acc_metric.reset() f1_metric.reset() image_files = _get_images(images_path) model = YOLO(model_name) model.eval() with torch.no_grad(): for img_path in image_files: class_name = img_path.parent.name if class_name not in class_to_idx: print(colorama.Fore.YELLOW + f"invalid image file: {img_path}") continue gt_label = class_to_idx[class_name] results = model(str(img_path), verbose=False) probs = results[0].probs.data pred_label = int(torch.argmax(probs).item()) pred_tensor = torch.tensor([pred_label]) gt_tensor = torch.tensor([gt_label]) acc_metric.update(pred_tensor, gt_tensor) f1_metric.update(pred_tensor, gt_tensor) acc = 
acc_metric.compute().item() f1 = f1_metric.compute().item() print(colorama.Fore.GREEN + f"Accuracy: {acc:.4f}\nF1 Score: {f1:.4f}")

执行结果如下图所示:

Detect主要测试代码如下:

def test_detect(model_name, images_path, txts_path): if model_name is None or not model_name or not Path(model_name).is_file(): raise FileNotFoundError(colorama.Fore.RED + f"{model_name} is not a file") if images_path is None or not images_path or not Path(images_path).is_dir(): raise FileNotFoundError(colorama.Fore.RED + f"{images_path} is not a directory") if txts_path is None or not txts_path or not Path(txts_path).is_dir(): raise FileNotFoundError(colorama.Fore.RED + f"{txts_path} is not a directory") image_files = _get_images(images_path) preds_all = [] targets_all = [] model = YOLO(model_name) model.eval() with torch.no_grad(): for img_path in image_files: txt_path = txts_path + "/" + img_path.stem + ".txt" if not Path(txt_path).exists(): raise FileNotFoundError(colorama.Fore.RED + f"{txt_path} does not exist") img = cv2.imread(str(img_path)) if img is None: raise FileNotFoundError(colorama.Fore.RED + f"unable to load image file: {img_path}") h, w = img.shape[:2] gt_boxes = [] gt_labels = [] with open(txt_path, mode="r", encoding="utf-8") as f: for line in f: parts = line.strip().split() if len(parts) != 5: raise RuntimeError(colorama.Fore.RED + f"{txt_path}: file content is incorrect") cls = int(parts[0]) cx, cy, bw, bh = map(float, parts[1:]) x1 = (cx - bw / 2) * w y1 = (cy - bh / 2) * h x2 = (cx + bw / 2) * w y2 = (cy + bh / 2) * h gt_boxes.append([x1, y1, x2, y2]) gt_labels.append(cls) if len(gt_boxes) == 0: gt_boxes = torch.zeros((0, 4)) gt_labels = torch.zeros((0,), dtype=torch.int64) else: gt_boxes = torch.tensor(gt_boxes, dtype=torch.float32) gt_labels = torch.tensor(gt_labels, dtype=torch.int64) results = model(str(img_path), verbose=False)[0] if results.boxes is None or len(results.boxes) == 0: pred_boxes = torch.zeros((0, 4)) pred_scores = torch.zeros((0,)) pred_labels = torch.zeros((0,), dtype=torch.int64) else: pred_boxes = results.boxes.xyxy.cpu() pred_scores = results.boxes.conf.cpu() pred_labels = results.boxes.cls.cpu().to(torch.int64) 
preds_all.append({"boxes": pred_boxes, "scores": pred_scores, "labels": pred_labels}) targets_all.append({"boxes": gt_boxes, "labels": gt_labels}) print(f"total samples: {len(preds_all)}") metric = MeanAveragePrecision(iou_type="bbox", class_metrics=True) metric.update(preds_all, targets_all) result = metric.compute() print(f"metrics result: {result}") map50 = result["map_50"].item() map5095 = result["map"].item() print(colorama.Fore.GREEN + f"mAP50: {map50:.4f}\nmAP50-95: {map5095:.4f}")

执行结果如下图所示:

Segment主要测试代码如下:

def _polygon_to_mask(polygons, h, w): mask = np.zeros((h, w), dtype=np.uint8) for poly in polygons: pts = np.array(poly, dtype=np.int32).reshape(-1, 2) cv2.fillPoly(mask, [pts], 1) return mask def test_segment(model_name, images_path, txts_path): if model_name is None or not model_name or not Path(model_name).is_file(): raise FileNotFoundError(colorama.Fore.RED + f"{model_name} is not a file") if images_path is None or not images_path or not Path(images_path).is_dir(): raise FileNotFoundError(colorama.Fore.RED + f"{images_path} is not a directory") if txts_path is None or not txts_path or not Path(txts_path).is_dir(): raise FileNotFoundError(colorama.Fore.RED + f"{txts_path} is not a directory") image_files = _get_images(images_path) model = YOLO(model_name) num_classes = len(model.names) + 1 # 0:background metric = MeanIoU(num_classes=num_classes, per_class=True, input_format="index") metric.reset() total = 0 target_size = (480, 480) model.eval() with torch.no_grad(): for img_path in image_files: txt_path = txts_path + "/" + img_path.stem + ".txt" if not Path(txt_path).exists(): raise FileNotFoundError(colorama.Fore.RED + f"{txt_path} does not exist") img = cv2.imread(str(img_path)) if img is None: raise FileNotFoundError(colorama.Fore.RED + f"unable to load image file: {img_path}") h, w = img.shape[:2] gt_mask = np.zeros((h, w), dtype=np.uint8) pred_mask = np.zeros((h, w), dtype=np.uint8) with open(txt_path, mode="r", encoding="utf-8") as f: for line in f: parts = list(map(float, line.strip().split())) cls = int(parts[0]) coords = parts[1:] pts = [] for i in range(0, len(coords), 2): x = coords[i] * w y = coords[i + 1] * h pts.append([x, y]) mask = _polygon_to_mask([pts], h, w) gt_mask[mask == 1] = cls + 1 results = model(str(img_path), verbose=False)[0] if results.masks is not None: masks = results.masks.data.cpu().numpy() classes = results.boxes.cls.cpu().numpy().astype(int) for i in range(len(masks)): m = masks[i] cls = classes[i] m = (m > 
0.5).astype(np.uint8) m = cv2.resize(m, (w, h), interpolation=cv2.INTER_NEAREST) pred_mask[m == 1] = cls + 1 pred_tensor = torch.tensor(cv2.resize(pred_mask, target_size, interpolation=cv2.INTER_NEAREST)).long() gt_tensor = torch.tensor(cv2.resize(gt_mask, target_size, interpolation=cv2.INTER_NEAREST)).long() metric.update(pred_tensor.unsqueeze(0), gt_tensor.unsqueeze(0)) total += 1 miou_per_class = metric.compute() print(f"metrics result(per class): {miou_per_class}") miou = miou_per_class[1:].mean().item() # remove backgroud print(colorama.Fore.GREEN + f"total samples: {total}\nmIoU: {miou:.4f}")

执行结果如下图所示:

GitHub:https://github.com/fengbingchun/NN_Test

http://www.jsqmd.com/news/706794/

相关文章:

  • AI代码审查实战:让CodeRabbit当你的第二双眼睛
  • 物理信息神经网络驱动的阻变存储器参数反演:从时序电压响应中精准提取二氧化钛ReRAM物理参数(Python)
  • 电脑软件《图片转PDF转换器》 - 新手入门指南
  • Unsloth Sglang Vllm核心区别和使用场景
  • Dubbo线程池策略详解:Fixed、Cached、Limited与Eager对比
  • 2026正规免费量化交易软件推荐榜:ea量化交易软件/什么是量化交易/手机量化交易软件/散户如何做量化交易/期货量化交易系统/选择指南 - 优质品牌商家
  • 循环优化设计
  • 从零开始学C语言:环境搭建与首个代码
  • 梯度下降算法详解:原理、实现与优化技巧
  • 零基础秒落地!魔珐星云打造专属法务数字人
  • 成都地区、H型钢、350X350X12X19、Q235B、包钢、现货批发供应 - 四川盛世钢联营销中心
  • 用户上周说有两个孩子,这周说有三个孩子,Agent 如何处理记忆冲突?
  • Weaviate向量数据库实战:从部署到多模态搜索与生产优化
  • PyTorch训练管理:检查点与早停技术详解
  • 成都地区、H型钢、700X300X13X14、Q235B、包钢、现货批发供应 - 四川盛世钢联营销中心
  • 成都地区、低合金H型钢、500X200X10X16、Q355B、包钢、现货批发供应 - 四川盛世钢联营销中心
  • 记录一次Jenkins构建任务的坑
  • HTML总结
  • 成都地区、H型钢、588X300X12X20、Q235B、包钢、现货批发供应 - 四川盛世钢联营销中心
  • 205套思维工具(转)
  • caj2pdf:3个技巧让知网CAJ文献在Linux上重获新生
  • 2026川渝地区耐火砖技术分享:耐火材料供应厂家/耐火材料厂商/耐火材料厂家/耐火材料哪家好/耐火材料批发/耐火材料报价/选择指南 - 优质品牌商家
  • 为什么你的Dev Container正在悄悄上传源码?揭秘.gitignore之外的5类敏感数据泄漏路径(企业级隔离方案已落地)
  • 共享记忆会毁掉系统 多智能体信息污染的五种典型路径
  • 贝叶斯信念网络:原理、构建与应用实践
  • Linearis:Rust高性能线性代数库的设计、应用与性能调优
  • 2026年4月宜宾家装公司排行:宜宾装修公司哪家好、宜宾装修公司推荐、宜宾装修公司电话、宜宾装饰公司口碑、宜宾装饰公司哪家好选择指南 - 优质品牌商家
  • 神经网络模型容量控制:节点数与层数优化指南
  • cuML通过PyPI安装:GPU数据科学的新突破
  • 魔珐星云打造上海历史大屏数字人