OpenGait(步态识别框架)的配置项说明
一、核心配置模块解读
1. data_cfg(数据配置)
核心作用:定义数据集来源、加载方式、测试集等基础数据参数。
表格
| 参数 | 说明 | 示例 |
|---|---|---|
| dataset_name | 训练数据集名称(仅支持 CASIA-B/OUMVLP) | CASIA-B |
| dataset_root | 数据集存储路径 | /data/CASIA-B |
| dataset_partition | 数据集划分文件(划分训练 / 测试集) | ./datasets/CASIA-B/CASIA-B.json |
| num_workers | 数据加载的线程数 | 8(根据 CPU 核心数调整) |
| cache | 是否将数据全加载到内存(加速训练,需大内存) | True/False |
| test_dataset_name | 测试数据集名称 | CASIA-B |
2. loss_cfg(损失函数配置)
核心作用:定义训练使用的损失函数(支持 TripletLoss/CrossEntropyLoss),支持多损失加权。
表格
| 参数 | 说明 | 示例 |
|---|---|---|
| type | 损失函数类型 | TripletLoss/CrossEntropyLoss |
| loss_term_weight | 损失权重(多损失时调整各损失占比) | 1.0(TripletLoss)、0.1(CrossEntropyLoss) |
| log_prefix | 损失日志前缀(便于区分不同损失) | triplet/softmax |
| margin(TripletLoss 专属) | 三元组损失的边际值 | 0.2 |
| scale(CrossEntropyLoss 专属) | 交叉熵损失的缩放系数 | 16 |
3. optimizer_cfg(优化器配置)
核心作用:定义优化器类型及参数,对齐 PyTorch 原生优化器。
表格
| 参数 | 说明 | 示例 |
|---|---|---|
| solver | 优化器类型 | SGD/Adam |
| lr | 学习率 | 0.1(SGD)、1e-4(Adam) |
| momentum | 动量(SGD 专属) | 0.9 |
| weight_decay | 权重衰减(防止过拟合) | 0.0005 |
4. scheduler_cfg(学习率调度器)
核心作用:定义学习率衰减策略,对齐 PyTorch 原生调度器。
表格
| 参数 | 说明 | 示例 |
|---|---|---|
| scheduler | 调度器类型 | MultiStepLR/CosineAnnealingLR |
| milestones(MultiStepLR 专属) | 学习率衰减的迭代节点 | [20000, 40000] |
| gamma | 衰减系数 | 0.1(每次衰减为原学习率的 10%) |
5. model_cfg(模型配置)
核心作用:定义模型结构,需参考框架的 Model Library。
表格
| 核心参数 | 说明 | 示例 |
|---|---|---|
| model | 模型名称(如 Baseline、GaitSet) | Baseline |
| backbone_cfg | 骨干网络配置(通道数、层结构) | in_channels:1, layers_cfg: [BC-64, M, ...] |
| bin_num | 特征分箱数(步态特征编码) | [16,8,4,2,1] |
6. evaluator_cfg(评估器配置)
核心作用:定义模型评估的规则(推理方式、指标、 checkpoint 加载等)。
表格
| 关键参数 | 说明 | 示例 |
|---|---|---|
| restore_hint | 加载的 checkpoint (迭代数 / 路径) | 60000(加载第 6 万迭代的权重) |
| save_name | 实验名称(用于输出目录) | Baseline_CASIA-B |
| eval_func | 评估函数(CASIA-B 用 identification) | evaluate_indoor_dataset |
| sampler | 推理采样器配置 | type: InferenceSampler, sample_type: all_ordered |
| metric | 距离计算方式(euc 欧氏距离 /cos 余弦距离) | euc |
| transform | 数据预处理(切黑边 / 不切) | BaseSilCuttingTransform(切黑边) |
7. trainer_cfg(训练器配置)
核心作用:定义训练流程(迭代数、采样器、 checkpoint 保存、BN 同步等)。
表格
| 关键参数 | 说明 | 示例 |
|---|---|---|
| total_iter | 总训练迭代数 | 60000 |
| log_iter | 日志打印间隔 | 100(每 100 迭代打印一次) |
| save_iter | 权重保存间隔 | 10000(每 1 万迭代保存一次) |
| sampler | 训练采样器(TripletSampler) | batch_size: [8,16](8 个身份,每个身份 16 个序列) |
| sample_type | 训练帧采样方式 | fixed_unordered(固定帧数,随机选帧) |
| sync_BN | 多卡同步 BN | True(多卡训练建议开启) |
| with_test | 训练中是否穿插测试 | False(默认关闭,避免拖慢训练) |
二、关键参数重点说明
1. 采样器(sampler)核心参数
训练 / 评估的采样器是步态识别的核心,需重点理解:
表格
| 场景 | sample_type 取值 | 含义 |
|---|---|---|
| 训练 | fixed_unordered | 固定帧数(如 30 帧),随机选取(无序) |
| 训练 | unfixed_ordered | 帧数在 [min,max] 间随机,按自然顺序选帧 |
| 评估 | all_ordered | 用完整序列,按自然顺序输入(保证测试一致性) |
训练 batch_size 格式为[P,K]:
- P:一个 batch 中的身份数(如 8);
- K:每个身份的序列数(如 16);
- 需结合硬件显存调整(P×K 越大,显存占用越高)。
2. 输出目录规则
输出目录 =output/${dataset_name}/${model}/${save_name},例如:
output/CASIA-B/Baseline/Baseline_CASIA-B,包含:
- log:训练日志;
- checkpoint:模型权重;
- summary:可视化 / 评估结果。
3. 优先级规则
自定义配置会覆盖default.yaml中的默认配置,需注意:
- 若未定义某参数,自动使用default.yaml的默认值;
- 自定义参数与默认参数冲突时,以自定义为准。
三、配置文件编写规范
1. 基础结构
所有配置需按模块分层编写(data_cfg/loss_cfg/...),示例框架:
yaml
data_cfg: dataset_name: CASIA-B dataset_root: /path/to/CASIA-B num_workers: 8 dataset_partition: ./datasets/CASIA-B/CASIA-B.json loss_cfg: - type: TripletLoss loss_term_weight: 1.0 margin: 0.2 log_prefix: triplet - type: CrossEntropyLoss loss_term_weight: 0.1 scale: 16 log_prefix: softmax # 其他模块(optimizer/scheduler/model/evaluator/trainer)按上述规则补充2. 适配不同数据集的注意点
表格
| 数据集 | 关键调整项 |
|---|---|
| CASIA-B | eval_func: identification,metric: euc |
| OUMVLP | 增大 batch_size(如 [16,16]),调整 total_iter(如 120000) |
3. 多卡训练配置
需开启trainer_cfg.sync_BN: True,并调整num_workers(建议为卡数 ×4),示例:
yaml
trainer_cfg: sync_BN: True enable_float16: True # 混合精度训练,节省显存 data_cfg: num_workers: 16 # 2卡×8四、常见问题与调优建议
显存不足:
- 降低
trainer_cfg.sampler.batch_size(如从 [8,16] 改为 [4,8]); - 开启
enable_float16: True(混合精度); - 减小
frames_num_fixed(训练帧数,如从 30 改为 20)。
- 降低
训练不收敛:
- 调整 TripletLoss 的 margin(如 0.2→0.1);
- 增大学习率(SGD 从 0.1→0.2)或调整权重衰减;
- 检查数据集划分文件是否正确(dataset_partition)。
评估精度低:
- 评估时用
sample_type: all_ordered(完整序列); - 切换 metric(euc/cos),CASIA-B 优先 euc;
- 确保模型权重加载正确(restore_hint 路径 / 迭代数无误)。
- 评估时用
