P2PNet训练数据预处理实战:用Python脚本快速生成ShanghaiTech等数据集的train.list
P2PNet训练数据预处理实战:用Python脚本快速生成ShanghaiTech等数据集的train.list
当你第一次拿到ShanghaiTech这样的标准人群计数数据集时,可能会被其原始的文件夹结构弄得一头雾水。images文件夹里散落着数百张图片,GT_txt目录下是对应的标注文件,但P2PNet等现代计数模型训练时需要的却是一个简单的train.list文件——这个文件应该包含每张图片路径及其对应标注文件的路径。本文将带你深入数据预处理的细节,用Python脚本高效解决这个看似简单却容易踩坑的问题。
1. 理解数据集目录结构
在开始编写脚本之前,我们需要先理清楚ShanghaiTech数据集的典型目录结构。以part_A_final为例,其原始解压后的目录通常如下:
part_A_final/ ├── test_data/ │ ├── images/ │ │ ├── IMG_1.jpg │ │ ├── IMG_2.jpg │ │ └── ... │ └── GT_txt/ │ ├── GT_IMG_1.txt │ ├── GT_IMG_2.txt │ └── ... └── train_data/ ├── images/ └── GT_txt/每个.txt标注文件的内容是每行一个点的坐标(x,y),表示图像中一个人的头部位置。P2PNet等模型训练时需要的是将这些图片和对应标注文件的路径配对,生成如下格式的train.list:
/path/to/part_A_final/train_data/images/IMG_1.jpg /path/to/part_A_final/train_data/GT_txt/GT_IMG_1.txt /path/to/part_A_final/train_data/images/IMG_2.jpg /path/to/part_A_final/train_data/GT_txt/GT_IMG_2.txt ...2. 基础Python脚本实现
下面是一个健壮的Python脚本实现,它包含了必要的错误检查和日志输出:
import os from tqdm import tqdm # 用于显示进度条 def generate_dataset_list(dataset_path, output_file): """ 生成P2PNet训练所需的文件列表 参数: dataset_path: 数据集根目录(包含images和GT_txt文件夹) output_file: 输出的list文件路径 """ image_dir = os.path.join(dataset_path, 'images') gt_dir = os.path.join(dataset_path, 'GT_txt') # 检查目录是否存在 if not os.path.exists(image_dir): raise FileNotFoundError(f"图片目录不存在: {image_dir}") if not os.path.exists(gt_dir): raise FileNotFoundError(f"标注目录不存在: {gt_dir}") # 获取所有jpg图片文件 image_files = [f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')] skipped_pairs = 0 with open(output_file, 'w') as f_out, tqdm(total=len(image_files), desc="处理进度") as pbar: for img_file in image_files: base_name = os.path.splitext(img_file)[0] gt_file = f"GT_{base_name}.txt" img_path = os.path.join(image_dir, img_file) gt_path = os.path.join(gt_dir, gt_file) # 检查文件对是否存在 if os.path.exists(gt_path): f_out.write(f"{img_path} {gt_path}\n") else: skipped_pairs += 1 print(f"警告: 标注文件缺失 - {gt_path}") pbar.update(1) print(f"\n处理完成! 成功配对 {len(image_files)-skipped_pairs} 个文件对, 跳过 {skipped_pairs} 个.") # 使用示例 if __name__ == "__main__": dataset_root = "/path/to/part_A_final/train_data" output_list = "./train.list" generate_dataset_list(dataset_root, output_list)这个脚本相比原始版本有几个重要改进:
- 增加了目录存在性检查,避免因路径错误导致脚本无声失败
- 使用tqdm添加了进度条,处理大型数据集时更友好
- 记录了跳过的文件对数,便于后期检查
- 添加了详细的错误提示信息
3. 脚本优化与高级功能
基础版本已经可以工作,但在实际项目中我们还可以进一步优化:
3.1 支持多种数据集格式
不同的人群计数数据集可能有不同的目录结构和命名约定。我们可以扩展脚本以支持多种常见格式:
def generate_dataset_list_advanced(dataset_path, output_file, dataset_type='ShanghaiTech'): """ 支持多种数据集格式的增强版生成器 参数: dataset_type: 数据集类型('ShanghaiTech', 'UCF-QNRF', 'NWPU') """ # 不同数据集的目录结构和命名约定 dataset_config = { 'ShanghaiTech': { 'image_dir': 'images', 'gt_dir': 'GT_txt', 'gt_prefix': 'GT_' }, 'UCF-QNRF': { 'image_dir': 'images', 'gt_dir': 'gt', 'gt_prefix': 'GT_' }, 'NWPU': { 'image_dir': 'images', 'gt_dir': 'ground_truth', 'gt_prefix': '' } } config = dataset_config.get(dataset_type) if not config: raise ValueError(f"不支持的数据集类型: {dataset_type}") image_dir = os.path.join(dataset_path, config['image_dir']) gt_dir = os.path.join(dataset_path, config['gt_dir']) # 其余代码与基础版本类似,只需修改gt_file的生成方式 gt_file = f"{config['gt_prefix']}{base_name}.txt"3.2 并行处理加速
对于包含数万张图片的大型数据集(如UCF-QNRF),我们可以使用多进程加速处理:
from multiprocessing import Pool def process_single_file(args): """处理单个文件对的辅助函数""" img_file, image_dir, gt_dir, config = args base_name = os.path.splitext(img_file)[0] gt_file = f"{config['gt_prefix']}{base_name}.txt" img_path = os.path.join(image_dir, img_file) gt_path = os.path.join(gt_dir, gt_file) if os.path.exists(gt_path): return f"{img_path} {gt_path}\n" return None def generate_dataset_list_parallel(dataset_path, output_file, dataset_type='ShanghaiTech', workers=4): """并行版本的数据集列表生成器""" config = get_dataset_config(dataset_type) # 假设已定义 image_dir = os.path.join(dataset_path, config['image_dir']) gt_dir = os.path.join(dataset_path, config['gt_dir']) image_files = [f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')] with Pool(workers) as pool, open(output_file, 'w') as f_out: args = [(img, image_dir, gt_dir, config) for img in image_files] for result in tqdm(pool.imap(process_single_file, args), total=len(image_files)): if result: f_out.write(result)3.3 验证数据集完整性
生成列表后,我们可以添加一个验证步骤确保所有文件都有效:
def validate_dataset_list(list_file): """验证生成的list文件中的所有文件对是否有效""" missing_files = 0 corrupted_files = 0 with open(list_file, 'r') as f: lines = f.readlines() for line in tqdm(lines, desc="验证文件"): img_path, gt_path = line.strip().split() # 检查文件是否存在 if not os.path.exists(img_path): print(f"图片文件缺失: {img_path}") missing_files += 1 continue if not os.path.exists(gt_path): print(f"标注文件缺失: {gt_path}") missing_files += 1 continue # 检查标注文件是否为空 if os.path.getsize(gt_path) == 0: print(f"空标注文件: {gt_path}") corrupted_files += 1 print(f"\n验证完成! 共检查 {len(lines)} 个文件对") print(f"缺失文件: {missing_files}") print(f"损坏文件: {corrupted_files}") return missing_files == 0 and corrupted_files == 04. 实际应用中的常见问题与解决方案
4.1 路径问题
在不同操作系统上运行脚本时可能会遇到路径分隔符问题。我们可以使用os.path模块来处理跨平台兼容性:
# 不推荐 (Windows上会使用反斜杠) image_path = dataset_path + '/images/' + image_file # 推荐 (跨平台兼容) image_path = os.path.join(dataset_path, 'images', image_file)4.2 文件名大小写问题
在Linux和macOS上,文件名是大小写敏感的,而Windows不敏感。为避免问题:
# 获取文件时统一转换为小写比较 image_files = [f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')]4.3 处理大量小文件
当数据集包含数万张图片时,简单的os.listdir()可能会很慢。可以考虑使用scandir:
from os import scandir def get_image_files_fast(image_dir): """更高效地获取图片文件列表""" with scandir(image_dir) as entries: return [entry.name for entry in entries if entry.is_file() and entry.name.lower().endswith('.jpg')]4.4 处理非标准命名
有时标注文件名可能与图片文件名不完全匹配,我们可以使用模糊匹配:
from difflib import get_close_matches def find_matching_gt_file(gt_dir, base_name): """模糊查找匹配的标注文件""" possible_gt = [f for f in os.listdir(gt_dir) if f.endswith('.txt')] matches = get_close_matches(f"GT_{base_name}.txt", possible_gt, n=1, cutoff=0.6) return matches[0] if matches else None5. 扩展到其他计算机视觉任务
虽然我们以人群计数为例,但类似的数据预处理流程也适用于其他任务:
5.1 目标检测任务
对于目标检测任务,标注通常是XML或JSON格式,我们可以修改脚本支持这些格式:
def generate_detection_list(dataset_path, output_file, annotation_ext='.xml'): image_dir = os.path.join(dataset_path, 'images') ann_dir = os.path.join(dataset_path, 'annotations') image_files = [f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')] with open(output_file, 'w') as f_out: for img_file in image_files: base_name = os.path.splitext(img_file)[0] ann_file = base_name + annotation_ext ann_path = os.path.join(ann_dir, ann_file) if os.path.exists(ann_path): f_out.write(f"{os.path.join(image_dir, img_file)} {ann_path}\n")5.2 语义分割任务
语义分割任务通常需要将图片与掩码图像配对:
def generate_segmentation_list(dataset_path, output_file): image_dir = os.path.join(dataset_path, 'images') mask_dir = os.path.join(dataset_path, 'masks') image_files = [f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')] with open(output_file, 'w') as f_out: for img_file in image_files: mask_file = os.path.splitext(img_file)[0] + '.png' # 假设掩码是png格式 mask_path = os.path.join(mask_dir, mask_file) if os.path.exists(mask_path): f_out.write(f"{os.path.join(image_dir, img_file)} {mask_path}\n")5.3 多任务学习
对于需要同时处理多种标注类型的多任务学习,我们可以扩展格式:
# 多任务list文件格式: 图片路径 标注1路径 标注2路径 ... f_out.write(f"{img_path} {gt_path} {det_path} {seg_path}\n")