当前位置: 首页 > news >正文

深度学习炼丹师的效率神器:手把手教你用Shell脚本批量跑模型(附argparse配置模板)

深度学习炼丹师的效率神器:手把手教你用Shell脚本批量跑模型(附argparse配置模板)

在深度学习模型开发中,我们常常需要反复调整超参数、更换模型架构或切换数据集进行测试。每次手动修改代码或命令行参数不仅效率低下,还容易出错。本文将介绍如何通过Shell脚本+argparse的组合拳,实现一键式多模型训练与参数网格搜索,让你的"炼丹"过程既高效又优雅。

1. argparse:Python脚本的参数化基石

argparse是Python标准库中的命令行参数解析模块,它能将脚本中的关键参数暴露给用户,实现运行时动态配置。一个典型的深度学习训练脚本通常包含以下核心参数:

import argparse def parse_args(): parser = argparse.ArgumentParser(description='模型训练参数配置') # 训练流程参数 parser.add_argument('--epochs', type=int, default=50, help='训练轮次') parser.add_argument('--batch_size', type=int, default=32, help='批次大小') parser.add_argument('--lr', type=float, default=1e-3, help='初始学习率') # 模型架构参数 parser.add_argument('--model', type=str, default='resnet18', help='模型名称(resnet18/densenet121)') parser.add_argument('--pretrained', action='store_true', help='是否使用预训练权重') # 数据相关参数 parser.add_argument('--data_dir', type=str, required=True, help='数据集根目录') parser.add_argument('--num_workers', type=int, default=4, help='数据加载线程数') return parser.parse_args() if __name__ == '__main__': args = parse_args() print(f'当前配置:{vars(args)}')

提示:使用required=True标记必须参数,避免遗漏关键配置;action='store_true'用于创建布尔型开关参数。

在脚本中使用这些参数时,只需通过args.参数名调用:

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) train_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

2. Shell脚本:批量执行的瑞士军刀

当我们需要测试不同模型架构或超参数组合时,手动逐个执行命令显然不够高效。Shell脚本可以完美解决这个问题,下面是一个基础模板:

#!/bin/bash # 定义公共参数 DATA_DIR="./dataset/cifar10" EPOCHS=50 BATCH_SIZE=128 # 模型列表 MODELS=("resnet18" "densenet121" "efficientnet_b0") # 学习率列表 LEARNING_RATES=(1e-3 5e-4 1e-4) for model in "${MODELS[@]}"; do for lr in "${LEARNING_RATES[@]}"; do echo "正在训练:model=${model}, lr=${lr}" python train.py \ --data_dir $DATA_DIR \ --epochs $EPOCHS \ --batch_size $BATCH_SIZE \ --model $model \ --lr $lr \ --output_dir "logs/${model}_lr${lr}" done done

这个脚本实现了:

  • 自动遍历3种模型架构和3种学习率组合
  • 每次训练生成独立的输出目录
  • 实时打印当前训练配置

3. 高级技巧:参数网格搜索与实验管理

3.1 嵌套循环实现多参数组合

通过嵌套循环,我们可以轻松实现多参数的网格搜索:

#!/bin/bash # 定义搜索空间 BATCH_SIZES=(32 64 128) LEARNING_RATES=(1e-2 1e-3 1e-4) OPTIMIZERS=("adam" "sgd") for bs in "${BATCH_SIZES[@]}"; do for lr in "${LEARNING_RATES[@]}"; do for opt in "${OPTIMIZERS[@]}"; do EXP_NAME="bs${bs}_lr${lr}_${opt}" echo "启动实验:${EXP_NAME}" python train.py \ --batch_size $bs \ --lr $lr \ --optimizer $opt \ --experiment_name $EXP_NAME done done done

3.2 实验结果的自动归档

为每个实验创建独立的日志目录是良好实践:

#!/bin/bash LOG_ROOT="./experiments" TIMESTAMP=$(date +"%Y%m%d_%H%M%S") for model in "resnet18" "resnet34"; do LOG_DIR="${LOG_ROOT}/${TIMESTAMP}_${model}" mkdir -p $LOG_DIR python train.py \ --model $model \ --log_dir $LOG_DIR \ 2>&1 | tee "${LOG_DIR}/train.log" done

关键点:

  • date +"%Y%m%d_%H%M%S"生成时间戳保证目录唯一性
  • mkdir -p自动创建目录
  • tee命令同时输出到屏幕和日志文件

3.3 并行执行加速实验

使用&wait实现有限并行:

#!/bin/bash MAX_JOBS=2 # 同时运行的任务数 CURRENT_JOBS=0 for lr in 1e-3 5e-4 1e-4; do ((CURRENT_JOBS++)) python train.py --lr $lr --job_id $CURRENT_JOBS & if (( CURRENT_JOBS == MAX_JOBS )); then wait CURRENT_JOBS=0 fi done wait # 等待所有任务完成

注意:并行执行需确保GPU内存充足,或使用CUDA_VISIBLE_DEVICES分配不同GPU

4. 实用模板库:常见深度学习任务脚本

4.1 模型对比测试模板

#!/bin/bash # 模型测试对比脚本 DATA_DIR="./data/imagenet" CONFIG="./configs/base.yaml" declare -A MODEL_CONFIGS=( ["resnet50"]="arch=resnet50,pretrained=true" ["vit_base"]="arch=vit,img_size=384" ["swin_tiny"]="arch=swin,window_size=7" ) for model in "${!MODEL_CONFIGS[@]}"; do echo "测试模型:${model}" python test.py \ --data_dir $DATA_DIR \ --config $CONFIG \ --model $model \ --model_config "${MODEL_CONFIGS[$model]}" \ --output_file "results/${model}_metrics.json" done

4.2 跨数据集评估模板

#!/bin/bash MODEL_PATH="./checkpoints/best_model.pth" DATASETS=("cifar10" "cifar100" "svhn") for dataset in "${DATASETS[@]}"; do python evaluate.py \ --dataset $dataset \ --data_root "./data/${dataset}" \ --model $MODEL_PATH \ --batch_size 64 \ --metrics "accuracy,precision,recall,f1" \ --save_to "eval_results/${dataset}_report.csv" done

4.3 超参数优化模板

#!/bin/bash # 学习率与优化器组合搜索 for lr in 1e-2 5e-3 1e-3; do for wd in 0 1e-4 1e-3; do python train.py \ --lr $lr \ --weight_decay $wd \ --config "configs/hparam_search.yaml" \ --run_name "lr${lr}_wd${wd}" done done

5. 错误处理与日志增强

5.1 添加错误检查机制

#!/bin/bash set -e # 遇到错误立即退出 function train_model() { local model=$1 local lr=$2 echo "[$(date)] 开始训练:${model} (lr=${lr})" if ! python train.py --model $model --lr $lr; then echo "[ERROR] 训练失败:${model}" return 1 fi echo "[$(date)] 训练完成:${model}" return 0 } # 调用示例 train_model "resnet18" 1e-3 || exit 1

5.2 结构化日志记录

#!/bin/bash log() { local level=$1 local message=$2 echo "$(date '+%Y-%m-%d %H:%M:%S') [${level}] ${message}" } log "INFO" "开始实验流程" for seed in 42 123 456; do log "DEBUG" "使用随机种子:${seed}" python train.py --seed $seed 2>&1 | tee "seed_${seed}.log" if [ ${PIPESTATUS[0]} -ne 0 ]; then log "ERROR" "种子 ${seed} 运行失败" exit 1 fi done log "INFO" "所有实验完成"

6. 可视化与结��分析

6.1 训练指标自动汇总

#!/bin/bash # 生成CSV格式的结果摘要 echo "model,lr,batch_size,final_acc,training_time" > results/summary.csv for log_file in logs/*.log; do model=$(grep "Model:" $log_file | awk '{print $2}') lr=$(grep "Learning rate:" $log_file | awk '{print $3}') acc=$(grep "Final accuracy:" $log_file | awk '{print $3}') time=$(grep "Training time:" $log_file | awk '{print $3}') echo "$model,$lr,$acc,$time" >> results/summary.csv done # 使用pandas生成分析报告 python -c " import pandas as pd df = pd.read_csv('results/summary.csv') print(df.describe().to_markdown()) " > results/analysis.md

6.2 实验结果对比表格

生成Markdown格式的对比表格:

#!/bin/bash cat << EOF > results/comparison.md # 模型性能对比 | 模型名称 | 准确率 | 训练时间 | 参数量 | |---------|--------|---------|--------| $(for dir in experiments/*; do model=$(basename $dir) acc=$(cat $dir/metrics.json | jq '.accuracy') time=$(cat $dir/metrics.json | jq '.training_time') params=$(cat $dir/metrics.json | jq '.parameters') echo "| $model | $acc | $time | $params |" done) EOF

7. 进阶技巧:动态参数生成

7.1 从配置文件生成参数

#!/bin/bash # 读取JSON配置生成训练命令 CONFIG_FILE="configs/experiments.json" jq -c '.experiments[]' $CONFIG_FILE | while read experiment; do name=$(echo $experiment | jq -r '.name') lr=$(echo $experiment | jq -r '.lr') bs=$(echo $experiment | jq -r '.batch_size') python train.py \ --experiment_name $name \ --lr $lr \ --batch_size $bs \ --config "configs/base.yaml" done

7.2 条件参数组合

#!/bin/bash # 根据条件生成不同参数组合 for model in "resnet18" "resnet34"; do if [ "$model" == "resnet18" ]; then lr_list=(1e-3 5e-4) bs_list=(64 128) else lr_list=(5e-4 1e-4) bs_list=(32 64) fi for lr in "${lr_list[@]}"; do for bs in "${bs_list[@]}"; do python train.py \ --model $model \ --lr $lr \ --batch_size $bs done done done
http://www.jsqmd.com/news/899880/

相关文章:

  • Swin Transformer实战:从零搭建PyTorch图像分类模型
  • 别再只用摇杆移动角色了!解锁Joystick Pack的5个隐藏用法:控制UI、镜头旋转与场景交互
  • 基于CODESYS与EtherCAT的步进电机单轴运动控制实践
  • 理工科毕业生福音:实测能准确生成图片、公式、代码、实验数据的AI论文网站
  • 高增益立方升压转换器设计:实现低应力、高效率的DC-DC升压方案
  • 基于蝙蝠侠协议的无人车自组网模块设计与户外实验验证
  • 出版社教学资源网系统的开发
  • 从零开发游戏需要学习的c#模块,第二十六章(多种敌人与基础 AI)
  • TVA现阶段快速进入的五大核心应用场景
  • 2025-2026年发动机缸盖工厂推荐:十大排行专业评测加工精度案例价格 - 品牌推荐
  • 保姆级教程:用ROS的navigation和move_base让小车自己跑起来(附避坑指南)
  • 5G网络基石:从APN到DNN的演进与核心配置解析
  • 异构加速器上并行FFT算法设计与性能优化实践
  • (良心整理)亲测靠谱的AI论文网站,毕业党收藏备用
  • 远程控制哪家稳?地铁高铁酒店WiFi实测,ToDesk弱网优化最强
  • 学术写作效率突破!2026全能型AI论文软件精选指南
  • AI智能体视觉开启人工智能时代新纪元
  • Unity手游开发:用Joystick Pack插件5分钟搞定虚拟摇杆,适配移动端触屏操作
  • HETI架构与堆叠寄存器文件:硬件加速中断上下文切换的嵌入式实时系统优化
  • 从零开发游戏需要学习的c#模块,第二十七章(远程攻击 —— 发射子弹)
  • 【仅限首批500家企业获取】ChatGPT客服话术智能诊断工具包(含话术熵值分析器+合规风险热力图+客户情绪拐点预测模型)
  • 量子网络全栈协同设计:从异构互联到可扩展架构的工程实践
  • 2025-2026年发动机缸盖工厂推荐:五大排行产品专业评测自动化产线防气孔缺陷注意事项 - 品牌推荐
  • 从一次偶发性RST探秘TCP协议栈与NAT的隐秘冲突
  • 智能制造的关键入口:从传统视觉到AI智能体视觉(系列)
  • 第一篇:为什么多个 Flow collect 必须 launch?——一篇讲透 Android 协程生命周期
  • SRT除法器性能优化:Skip-Zero策略的原理、实现与Chisel实践
  • 迭代扰动粒子滤波:突破重采样瓶颈,实现并行化贝叶斯状态估计
  • AIBOX-1684X系统固件升级入门教程
  • ChatGPT产品描述生成失效真相(90%团队踩中的5个认知陷阱)