保姆级教程:用Python脚本搞定CelebAMask-HQ数据集预处理与可视化(附完整代码)
从零掌握CelebAMask-HQ数据集:完整预处理与可视化实战指南
当第一次打开CelebAMask-HQ数据集时,面对分散在多个文件夹中的19类面部组件mask文件,即使是经验丰富的开发者也可能感到无从下手。这个包含3万张高分辨率人脸图像的数据集,每张图都精确标注了皮肤、鼻子、眼睛等19个面部区域的语义分割mask,是进行人脸编辑、虚拟化妆等计算机视觉任务的黄金标准。但原始数据组织方式让直接使用变得困难——我们需要将这些零散的标注文件整合成结构化数据,并生成直观的可视化效果。
1. 环境准备与数据整理
在开始处理前,我们需要确保Python环境已安装必要的库。建议使用conda创建专属环境:
conda create -n celebamask python=3.8 conda activate celebamask pip install opencv-python pillow numpy pandas数据集目录结构应该如下安排:
CelebAMask-HQ/ ├── CelebA-HQ-img/ # 原始图像 ├── CelebAMask-HQ-mask-anno/ # 原始标注 ├── CelebA-HQ-to-CelebA-mapping.txt └── Data_preprocessing/ # 我们的处理脚本关键提示:从GitHub克隆仓库时,注意检查
CelebAMask-HQ-mask-anno文件夹是否完整。常见问题是部分mask文件下载失败,导致后续处理报错。
2. Mask整合:从碎片到完整标注
原始数据将每个面部组件的mask单独保存,我们需要将它们合并为每张图一个完整mask。创建g_mask.py:
import os import cv2 import numpy as np # 19个面部组件的定义顺序很重要,影响后续可视化颜色映射 LABELS = [ 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth' ] def merge_masks(base_dir, output_dir, img_count=30000): os.makedirs(output_dir, exist_ok=True) for img_id in range(img_count): # 每2000张图一个子文件夹 subfolder = str(img_id // 2000) base_mask = np.zeros((512, 512), dtype=np.uint8) for idx, label in enumerate(LABELS): mask_path = os.path.join( base_dir, subfolder, f"{img_id:05d}_{label}.png" ) if os.path.exists(mask_path): mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) base_mask[mask > 0] = idx + 1 # 为每个组件分配唯一编号 cv2.imwrite(f"{output_dir}/{img_id}.png", base_mask)运行后会生成包含完整mask的CelebAMaskHQ-mask目录。常见问题排查:
- 报错:"NoneType' object has no attribute 'shape'"
- 原因:mask文件缺失或路径错误
- 解决:检查
CelebAMask-HQ-mask-anno下的文件名是否匹配00000_skin.png格式
3. 彩色可视化:让标注一目了然
黑白mask不便于直观检查,我们通过g_color.py添加颜色:
from PIL import Image import numpy as np # 为每个面部组件定义醒目颜色 COLOR_MAP = [ [0,0,0], # 背景 [204,0,0], # 皮肤-红 [76,153,0], # 鼻子-绿 [204,204,0], # 眼周-黄 [51,51,255], # 左眼-蓝 [204,0,204], # 右眼-紫 [0,255,255], # 左眉-青 [255,204,204], # 右眉-粉 [102,51,0], # 左耳-棕 [255,0,0], # 右耳-亮红 [102,204,0], # 嘴-黄绿 [255,255,0], # 上唇-亮黄 [0,0,153], # 下唇-深蓝 [0,0,204], # 头发-蓝 [255,51,153], # 帽子-粉红 [0,204,204], # 耳环-青绿 [0,51,0], # 颈部-深绿 [255,153,51], # 衣领-橙 [0,204,0] # 衣服-亮绿 ] def colorize_masks(mask_dir, output_dir): os.makedirs(output_dir, exist_ok=True) for mask_file in os.listdir(mask_dir): mask_path = os.path.join(mask_dir, mask_file) colored = np.zeros((512,512,3), dtype=np.uint8) mask = np.array(Image.open(mask_path)) for idx, color in enumerate(COLOR_MAP): colored[mask == idx] = color Image.fromarray(colored).save(f"{output_dir}/{mask_file}")生成的效果图中,不同面部区域会呈现鲜明对比色。调试技巧:
- 如果某些区域颜色异常,检查
COLOR_MAP顺序是否与LABELS完全对应 - 使用
PIL.Image.show()快速预览结果
4. 数据集划分:与CelebA保持一致
为确保与现有研究可比性,我们按CelebA的原始划分方式拆分数据。创建g_partition.py:
import pandas as pd from shutil import copyfile def split_dataset(mask_dir, img_dir, mapping_file): # 创建输出目录 splits = ['train', 'val', 'test'] for split in splits: os.makedirs(f"{split}_label", exist_ok=True) os.makedirs(f"{split}_img", exist_ok=True) # 读取CelebA的原始划分映射 mapping = pd.read_csv(mapping_file, delim_whitespace=True, header=None) # 按CelebA的划分标准进行分配 for idx, celebA_id in enumerate(mapping[1]): if 162771 <= celebA_id < 182638: # 验证集 dest = 'val' elif celebA_id >= 182638: # 测试集 dest = 'test' else: # 训练集 dest = 'train' # 复制图像和标注 copyfile(f"{mask_dir}/{idx}.png", f"{dest}_label/{idx}.png") copyfile(f"{img_dir}/{idx}.jpg", f"{dest}_img/{idx}.jpg")最终得到的数据集结构:
train_img/ # 约24,000张训练图像 train_label/ # 对应的训练标注 val_img/ # 约3,000张验证图像 val_label/ test_img/ # 约3,000张测试图像 test_label/5. 高级技巧与问题解决方案
在实际处理过程中,有几个关键点需要特别注意:
文件名不匹配问题原始脚本中可能存在路径硬编码问题。现代Python推荐使用pathlib进行路径操作:
from pathlib import Path mask_path = Path(base_dir) / subfolder / f"{img_id:05d}_{label}.png" if mask_path.exists(): mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)批量处理加速对于3万张图像,可以使用多进程加速:
from multiprocessing import Pool def process_image(args): img_id, base_dir, output_dir = args # ...处理逻辑... with Pool(8) as p: # 使用8个进程 p.map(process_image, [(i, base_dir, output_dir) for i in range(30000)])可视化检查工具创建Jupyter notebook快速检查任意图像的处理结果:
import matplotlib.pyplot as plt def show_sample(img_id): fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,6)) img = plt.imread(f"CelebA-HQ-img/{img_id}.jpg") mask = plt.imread(f"CelebAMask-HQ-mask-color/{img_id}.png") ax1.imshow(img) ax2.imshow(mask) plt.show()这套预处理流程不仅适用于CelebAMask-HQ,其设计思路也可迁移到其他分割数据集的处理中。关键在于理解数据组织逻辑、建立可靠的错误检查机制,以及提供直观的可视化验证手段。
