从零到一:CVPR2024 HAT模型复现全流程与避坑指南
1. 环境搭建:避开版本冲突的雷区
复现HAT模型的第一步就是搭建开发环境,这也是最容易踩坑的环节。我最初直接按照官方requirements.txt安装依赖,结果在PyTorch和CUDA版本兼容性上栽了跟头。经过多次尝试,发现以下组合最稳定:
conda create -n hat python=3.9 conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 -c pytorch这里有个隐藏陷阱:官方代码虽然支持PyTorch 2.x,但实际测试发现3.10以上Python版本会导致MSDeformableAttn编译失败。建议先用docker拉取基础镜像再配置环境:
FROM nvidia/cuda:11.7.1-devel-ubuntu20.04 RUN apt-get update && apt-get install -y git python3.9Detectron2的安装更是个技术活。直接pip install会报错,必须从源码编译。我推荐先卸载已有版本再重装:
git clone https://github.com/facebookresearch/detectron2.git cd detectron2 && pip install -e .遇到"Could not build wheels for detectron2"错误时,大概率是gcc版本问题。Ubuntu 20.04需要手动安装g++-9:
sudo apt install g++-9 export CXX=/usr/bin/g++-92. 数据准备:COCO-Search18的变形记
官方文档对数据准备的描述比较模糊,我花了三天时间才理清完整流程。关键点在于数据集需要三种形态转换:
- 图像尺寸转换:原始1680x1050需压缩到512x320
- 标注文件合并:train/val/test三个JSON要合成一个
- 语义文件移植:要从Gazeformer项目迁移语义标签
用Pillow批量处理图像时,注意保持宽高比避免变形。这里给出优化后的resize脚本:
from PIL import Image import os def smart_resize(src_dir, dst_dir, size=(512,320)): os.makedirs(dst_dir, exist_ok=True) for img_name in os.listdir(src_dir): img = Image.open(f"{src_dir}/{img_name}") # 保持宽高比的resize img.thumbnail(size, Image.Resampling.LANCZOS) img.save(f"{dst_dir}/{img_name}")JSON合并时有个坑:不同split的字段结构可能不一致。建议先用json.tool验证格式:
python -m json.tool coco_search_fixations_512x320_val.json3. 模型训练:参数调优实战录
开始训练前务必检查config文件,我修改了这些关键参数:
{ "batch_size": 8, // 显存不足时可降至4 "num_workers": 4, // 根据CPU核心数调整 "lr": 1e-4, // 初始学习率 "max_epochs": 50, // 早停设为30 "patience": 5 // 验证集loss不降时停止 }启动训练时推荐用nohup记录日志:
nohup python train.py --hparams configs/coco_search18.json \ --dataset-root ./data > train.log 2>&1 &遇到显存爆炸(OOM)问题时,可以尝试梯度累积:
# 在train.py中加入 accumulate_steps = 4 loss.backward() if step % accumulate_steps == 0: optimizer.step() optimizer.zero_grad()4. 结果可视化:让注意力轨迹说话
官方没有提供可视化工具,我基于matplotlib开发了动态扫描路径展示:
def plot_animated_scanpath(img, xs, ys): fig, ax = plt.subplots() ax.imshow(img) line, = ax.plot([], [], 'y-', lw=2) dots = ax.scatter([], [], c='red', s=50) def init(): line.set_data([], []) dots.set_offsets(np.empty((0,2))) return line, dots def update(frame): x_data = xs[:frame+1] y_data = ys[:frame+1] line.set_data(x_data, y_data) dots.set_offsets(np.c_[x_data,y_data]) return line, dots ani = FuncAnimation(fig, update, frames=len(xs), init_func=init, blit=True) plt.close() return ani保存动画建议用HTML格式,方便在Jupyter中直接查看:
ani.save('scanpath.html', writer='html', fps=2)5. 避坑宝典:血泪经验总结
CUDA版本问题:如果遇到"undefined symbol: _ZN3c105ErrorC1ENS_14SourceLocationESs"错误,说明CUDA和PyTorch版本不匹配。解决步骤:
- 确认CUDA版本:
nvcc --version - 安装对应PyTorch版本
- 清除缓存:
rm -rf ~/.cache/torch
- 确认CUDA版本:
数据加载卡顿:当发现数据加载成为瓶颈时,可以:
- 将数据预加载到内存
- 使用更快的存储设备(如NVMe SSD)
- 增加num_workers数量(不要超过CPU核心数)
评估指标异常:语义得分全为零的情况,检查:
- segmentation_maps目录是否完整
- 文件命名是否与代码中的硬编码匹配
- numpy文件是否损坏(用np.load测试)
多GPU训练问题:使用DataParallel时注意:
- batch_size要能被GPU数量整除
- 模型参数要在主GPU上初始化
- 使用
torch.cuda.empty_cache()定期清理显存
在模型微调阶段,我发现增加学习率warmup能显著提升稳定性。具体实现:
from torch.optim.lr_scheduler import LambdaLR def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): def f(x): if x >= warmup_iters: return 1 alpha = float(x) / warmup_iters return warmup_factor * (1 - alpha) + alpha return LambdaLR(optimizer, f)