行人轨迹预测入门:如何用ETH和UCY数据集训练你的第一个模型
行人轨迹预测实战指南:从ETH/UCY数据集到Baseline模型构建
行人轨迹预测作为计算机视觉和机器人导航领域的核心课题,正在智能监控、自动驾驶和社交机器人等场景中发挥越来越重要的作用。对于刚接触这一领域的研究者和工程师而言,ETH和UCY这两个经典数据集就像打开行人轨迹预测大门的钥匙——它们不仅提供了真实场景下的运动轨迹数据,更成为了学术界衡量模型性能的基准测试平台。本文将带您从零开始,逐步掌握这两个数据集的特性和使用方法,最终完成一个可运行的轨迹预测baseline模型。
1. 认识行人轨迹预测的核心数据集
1.1 ETH与UCY数据集概览
在行人轨迹预测研究中,ETH Walking Pedestrians (EWAP)和UCY crowds这对"黄金搭档"已经服务学术界超过十年。它们之所以经久不衰,关键在于:
- 真实场景采集:所有数据均来自欧洲城市真实 pedestrian zone 的监控视频,记录了行人自然的移动模式
- 精细标注:每帧中每个行人的位置、速度都被精确标注(ETH数据集达到2.5fps,UCY为2fps)
- 社交互动丰富:包含大量行人相遇、避让、群组移动等复杂社交行为场景
表:两个数据集的基本参数对比
| 特性 | ETH数据集 | UCY数据集 |
|---|---|---|
| 采集地点 | 苏黎世ETH校园 | 塞浦路斯大学 |
| 场景数量 | 2个(eth, hotel) | 3个(student, univ, zara) |
| 平均行人数量 | 4.6人/帧 | 4.2人/帧 |
| 标注频率 | 2.5Hz | 2Hz |
| 特殊属性 | 包含单应性矩阵 | 包含视线方向 |
1.2 数据获取与预处理要点
获取这两个数据集需要一些技巧:
# 推荐的数据获取路径 wget https://data.vision.ee.ethz.ch/cvl/aess/dataset/ewap_dataset_full.tgz # ETH完整版 wget https://graphics.cs.ucy.ac.cy/research/downloads/crowd-data # UCY官方源数据解压后,您会看到以下关键文件:
- obsmat.txt:主要轨迹数据文件
- H.txt:单应性矩阵(用于图像坐标到世界坐标的转换)
- README.md:各字段的详细说明
注意:原始数据中的Z轴信息(pos_z, v_z)在实际研究中通常被忽略,因为行人运动主要在二维平面
2. 数据工程:从原始数据到模型输入
2.1 轨迹数据解析实战
让我们用Python代码实际解析一个ETH数据片段:
import numpy as np def load_eth_data(file_path): data = np.loadtxt(file_path) # 列说明: [帧号 行人ID x坐标 y坐标 x速度 y速度] trajectories = {} for row in data: ped_id = int(row[1]) if ped_id not in trajectories: trajectories[ped_id] = [] trajectories[ped_id].append([row[0], row[2], row[3], row[4], row[5]]) return trajectories这个简单的解析器会将原始数据转换为以行人ID为键的字典,每个值是该行人的完整轨迹序列。
2.2 关键预处理步骤
- 坐标转换:使用H.txt中的单应性矩阵将图像坐标转为世界坐标(单位:米)
- 轨迹切片:将连续轨迹切割为8秒观察+4.8秒预测的标准片段
- 归一化处理:对坐标进行zero-mean归一化
- 社交关系构建:计算行人间的相对距离和运动方向
表:预处理中的典型参数设置
| 参数 | 推荐值 | 说明 |
|---|---|---|
| 观察长度 | 20帧(8秒) | 历史轨迹窗口 |
| 预测长度 | 12帧(4.8秒) | 需要预测的未来窗口 |
| 采样间隔 | 0.4秒 | ETH/UCY的标准帧间隔 |
| 交互半径 | 3米 | 考虑社交影响的阈值距离 |
3. Baseline模型构建与训练
3.1 模型架构选择
对于初学者,我推荐从这两种基础架构入手:
线性回归Baseline
- 简单加权历史位置预测未来轨迹
- 训练速度快,可作为性能下限参考
LSTM社交模型
- 使用LSTM编码个体轨迹
- 通过池化层(Pooling)捕捉社交交互
- 平衡复杂度和预测精度
import torch import torch.nn as nn class SocialLSTM(nn.Module): def __init__(self, input_dim=2, hidden_dim=64): super().__init__() self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True) self.pool = nn.MaxPool1d(kernel_size=5) self.fc = nn.Linear(hidden_dim, 24) # 预测12个时间点的(x,y) def forward(self, x): # x形状: [batch, seq_len, input_dim] out, _ = self.lstm(x) out = self.pool(out.transpose(1, 2)).squeeze() return self.fc(out)3.2 训练技巧与评估
训练行人轨迹预测模型时,有几个关键注意事项:
- 损失函数选择:平均位移误差(ADE)和最终位移误差(FDE)是标准指标
- 数据增强:对轨迹进行随机旋转和缩放提升泛化性
- 验证策略:采用leave-one-out交叉验证(在4个场景上训练,剩下1个测试)
提示:初学者常犯的错误是只关注个体轨迹而忽略社交因素。即使简单模型,加入基本的邻居信息也能显著提升性能
4. 进阶方向与性能优化
4.1 从Baseline到SOTA的路径
当您掌握了基础模型后,可以考虑以下进阶方向:
- 社交注意力机制:让模型自动学习关注最重要的邻居
- 场景约束融合:引入地图信息避免不合理的预测(如穿过墙壁)
- 多模态预测:输出多种可能的未来轨迹及其概率
- 生成式模型:使用GAN或Diffusion模型生成更真实的轨迹
4.2 实际项目中的调优经验
在真实应用中,我们发现这些策略特别有效:
- 轨迹平滑:对原始数据应用卡尔曼滤波减少标注噪声
- 速度归一化:对不同行人类型(成人/儿童)进行速度标准化
- 早停策略:当验证集FDE在10个epoch内无改进时停止训练
表:典型baseline模型在ETH/UCY上的性能参考
| 模型 | ADE(m) | FDE(m) | 训练时间 |
|---|---|---|---|
| 线性回归 | 1.23 | 2.45 | <1分钟 |
| 社交LSTM | 0.87 | 1.62 | ~2小时 |
| 社交GAN | 0.61 | 1.18 | ~8小时 |
5. 常见问题与解决方案
在实际项目开发中,这些经验可能帮您少走弯路:
- 数据不平衡问题:UCY的zara场景行人密度明显高于其他场景,建议训练时进行样本加权
- 坐标跳跃异常:检查是否正确处理了单应性矩阵转换
- GPU内存不足:减小batch size或使用梯度累积
- 过拟合现象:增加dropout率或添加L2正则化
# 解决过拟合的模型配置示例 model = SocialLSTM() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4) # L2正则化 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') # 动态学习率行人轨迹预测的魅力在于它完美结合了计算机视觉、机器学习和行为心理学。当我第一次看到自己训练的模型成功预测出人群分流模式时,那种成就感至今难忘。建议初学者不要急于复现复杂论文,而是先扎实理解数据特性——在ETH酒店场景中观察行人如何优雅地避让,或者在UCY校园场景分析学生群体的移动规律,这些直觉对模型设计至关重要。
