深度学习炼丹师的效率神器:手把手教你用Shell脚本批量跑模型(附argparse配置模板)
深度学习炼丹师的效率神器:手把手教你用Shell脚本批量跑模型(附argparse配置模板)
在深度学习模型开发中,我们常常需要反复调整超参数、更换模型架构或切换数据集进行测试。每次手动修改代码或命令行参数不仅效率低下,还容易出错。本文将介绍如何通过Shell脚本+argparse的组合拳,实现一键式多模型训练与参数网格搜索,让你的"炼丹"过程既高效又优雅。
1. argparse:Python脚本的参数化基石
argparse是Python标准库中的命令行参数解析模块,它能将脚本中的关键参数暴露给用户,实现运行时动态配置。一个典型的深度学习训练脚本通常包含以下核心参数:
import argparse def parse_args(): parser = argparse.ArgumentParser(description='模型训练参数配置') # 训练流程参数 parser.add_argument('--epochs', type=int, default=50, help='训练轮次') parser.add_argument('--batch_size', type=int, default=32, help='批次大小') parser.add_argument('--lr', type=float, default=1e-3, help='初始学习率') # 模型架构参数 parser.add_argument('--model', type=str, default='resnet18', help='模型名称(resnet18/densenet121)') parser.add_argument('--pretrained', action='store_true', help='是否使用预训练权重') # 数据相关参数 parser.add_argument('--data_dir', type=str, required=True, help='数据集根目录') parser.add_argument('--num_workers', type=int, default=4, help='数据加载线程数') return parser.parse_args() if __name__ == '__main__': args = parse_args() print(f'当前配置:{vars(args)}')提示:使用
required=True标记必须参数,避免遗漏关键配置;action='store_true'用于创建布尔型开关参数。
在脚本中使用这些参数时,只需通过args.参数名调用:
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)2. Shell脚本:批量执行的瑞士军刀
当我们需要测试不同模型架构或超参数组合时,手动逐个执行命令显然不够高效。Shell脚本可以完美解决这个问题,下面是一个基础模板:
#!/bin/bash # 定义公共参数 DATA_DIR="./dataset/cifar10" EPOCHS=50 BATCH_SIZE=128 # 模型列表 MODELS=("resnet18" "densenet121" "efficientnet_b0") # 学习率列表 LEARNING_RATES=(1e-3 5e-4 1e-4) for model in "${MODELS[@]}"; do for lr in "${LEARNING_RATES[@]}"; do echo "正在训练:model=${model}, lr=${lr}" python train.py \ --data_dir $DATA_DIR \ --epochs $EPOCHS \ --batch_size $BATCH_SIZE \ --model $model \ --lr $lr \ --output_dir "logs/${model}_lr${lr}" done done这个脚本实现了:
- 自动遍历3种模型架构和3种学习率组合
- 每次训练生成独立的输出目录
- 实时打印当前训练配置
3. 高级技巧:参数网格搜索与实验管理
3.1 嵌套循环实现多参数组合
通过嵌套循环,我们可以轻松实现多参数的网格搜索:
#!/bin/bash # 定义搜索空间 BATCH_SIZES=(32 64 128) LEARNING_RATES=(1e-2 1e-3 1e-4) OPTIMIZERS=("adam" "sgd") for bs in "${BATCH_SIZES[@]}"; do for lr in "${LEARNING_RATES[@]}"; do for opt in "${OPTIMIZERS[@]}"; do EXP_NAME="bs${bs}_lr${lr}_${opt}" echo "启动实验:${EXP_NAME}" python train.py \ --batch_size $bs \ --lr $lr \ --optimizer $opt \ --experiment_name $EXP_NAME done done done3.2 实验结果的自动归档
为每个实验创建独立的日志目录是良好实践:
#!/bin/bash LOG_ROOT="./experiments" TIMESTAMP=$(date +"%Y%m%d_%H%M%S") for model in "resnet18" "resnet34"; do LOG_DIR="${LOG_ROOT}/${TIMESTAMP}_${model}" mkdir -p $LOG_DIR python train.py \ --model $model \ --log_dir $LOG_DIR \ 2>&1 | tee "${LOG_DIR}/train.log" done关键点:
date +"%Y%m%d_%H%M%S"生成时间戳保证目录唯一性mkdir -p自动创建目录tee命令同时输出到屏幕和日志文件
3.3 并行执行加速实验
使用&和wait实现有限并行:
#!/bin/bash MAX_JOBS=2 # 同时运行的任务数 CURRENT_JOBS=0 for lr in 1e-3 5e-4 1e-4; do ((CURRENT_JOBS++)) python train.py --lr $lr --job_id $CURRENT_JOBS & if (( CURRENT_JOBS == MAX_JOBS )); then wait CURRENT_JOBS=0 fi done wait # 等待所有任务完成注意:并行执行需确保GPU内存充足,或使用
CUDA_VISIBLE_DEVICES分配不同GPU
4. 实用模板库:常见深度学习任务脚本
4.1 模型对比测试模板
#!/bin/bash # 模型测试对比脚本 DATA_DIR="./data/imagenet" CONFIG="./configs/base.yaml" declare -A MODEL_CONFIGS=( ["resnet50"]="arch=resnet50,pretrained=true" ["vit_base"]="arch=vit,img_size=384" ["swin_tiny"]="arch=swin,window_size=7" ) for model in "${!MODEL_CONFIGS[@]}"; do echo "测试模型:${model}" python test.py \ --data_dir $DATA_DIR \ --config $CONFIG \ --model $model \ --model_config "${MODEL_CONFIGS[$model]}" \ --output_file "results/${model}_metrics.json" done4.2 跨数据集评估模板
#!/bin/bash MODEL_PATH="./checkpoints/best_model.pth" DATASETS=("cifar10" "cifar100" "svhn") for dataset in "${DATASETS[@]}"; do python evaluate.py \ --dataset $dataset \ --data_root "./data/${dataset}" \ --model $MODEL_PATH \ --batch_size 64 \ --metrics "accuracy,precision,recall,f1" \ --save_to "eval_results/${dataset}_report.csv" done4.3 超参数优化模板
#!/bin/bash # 学习率与优化器组合搜索 for lr in 1e-2 5e-3 1e-3; do for wd in 0 1e-4 1e-3; do python train.py \ --lr $lr \ --weight_decay $wd \ --config "configs/hparam_search.yaml" \ --run_name "lr${lr}_wd${wd}" done done5. 错误处理与日志增强
5.1 添加错误检查机制
#!/bin/bash set -e # 遇到错误立即退出 function train_model() { local model=$1 local lr=$2 echo "[$(date)] 开始训练:${model} (lr=${lr})" if ! python train.py --model $model --lr $lr; then echo "[ERROR] 训练失败:${model}" return 1 fi echo "[$(date)] 训练完成:${model}" return 0 } # 调用示例 train_model "resnet18" 1e-3 || exit 15.2 结构化日志记录
#!/bin/bash log() { local level=$1 local message=$2 echo "$(date '+%Y-%m-%d %H:%M:%S') [${level}] ${message}" } log "INFO" "开始实验流程" for seed in 42 123 456; do log "DEBUG" "使用随机种子:${seed}" python train.py --seed $seed 2>&1 | tee "seed_${seed}.log" if [ ${PIPESTATUS[0]} -ne 0 ]; then log "ERROR" "种子 ${seed} 运行失败" exit 1 fi done log "INFO" "所有实验完成"6. 可视化与结��分析
6.1 训练指标自动汇总
#!/bin/bash # 生成CSV格式的结果摘要 echo "model,lr,batch_size,final_acc,training_time" > results/summary.csv for log_file in logs/*.log; do model=$(grep "Model:" $log_file | awk '{print $2}') lr=$(grep "Learning rate:" $log_file | awk '{print $3}') acc=$(grep "Final accuracy:" $log_file | awk '{print $3}') time=$(grep "Training time:" $log_file | awk '{print $3}') echo "$model,$lr,$acc,$time" >> results/summary.csv done # 使用pandas生成分析报告 python -c " import pandas as pd df = pd.read_csv('results/summary.csv') print(df.describe().to_markdown()) " > results/analysis.md6.2 实验结果对比表格
生成Markdown格式的对比表格:
#!/bin/bash cat << EOF > results/comparison.md # 模型性能对比 | 模型名称 | 准确率 | 训练时间 | 参数量 | |---------|--------|---------|--------| $(for dir in experiments/*; do model=$(basename $dir) acc=$(cat $dir/metrics.json | jq '.accuracy') time=$(cat $dir/metrics.json | jq '.training_time') params=$(cat $dir/metrics.json | jq '.parameters') echo "| $model | $acc | $time | $params |" done) EOF7. 进阶技巧:动态参数生成
7.1 从配置文件生成参数
#!/bin/bash # 读取JSON配置生成训练命令 CONFIG_FILE="configs/experiments.json" jq -c '.experiments[]' $CONFIG_FILE | while read experiment; do name=$(echo $experiment | jq -r '.name') lr=$(echo $experiment | jq -r '.lr') bs=$(echo $experiment | jq -r '.batch_size') python train.py \ --experiment_name $name \ --lr $lr \ --batch_size $bs \ --config "configs/base.yaml" done7.2 条件参数组合
#!/bin/bash # 根据条件生成不同参数组合 for model in "resnet18" "resnet34"; do if [ "$model" == "resnet18" ]; then lr_list=(1e-3 5e-4) bs_list=(64 128) else lr_list=(5e-4 1e-4) bs_list=(32 64) fi for lr in "${lr_list[@]}"; do for bs in "${bs_list[@]}"; do python train.py \ --model $model \ --lr $lr \ --batch_size $bs done done done