当前位置: 首页 > news >正文

深度学习 —— 学习率衰减策略

目录

学习率策略

1. 先说结论:

2. 图例:各种学习率下的图

3. 学习率的方式

4. 公式:

4. 神经网络的训练流程

5. 完整代码示例


学习率策略

模型调优的时候可能才会用

1. 先说结论:

① 学习率小, 梯度下降慢

② 学习率大, 梯度下降快

③ 学习率过大,可能导致梯度震荡或暴涨

2. 图例:各种学习率下的图

lr = [0.01, 0.1, 0.125, 0.2, 0.4]

3. 学习率的方式

① 等间隔学习率衰减. optim.lr_scheduler.StepLR

② 指定间隔学习率衰减 optim.lr_scheduler.MultiStepLR

③ 指数学习率衰减. optim.lr_scheduler.ExponentialLR

上图1 第一行代码

上图2 第二行代码

上图3 第三行代码

# 学习率调度器 """ optimizer: 梯度下降优化器 step_size:间隔周期 gama: 衰减系数 milestones: 指定间隔 调整点。比如 【50,125,160】 那就是51开始 126开始 161开始 """ scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=50,gamma=0.5) # 指定间隔学习率衰减策略 scheduler = optim.lr_scheduler.MultiStepLR(optimizer,milestones=[50,125,160],gamma=0.5) # 指数间隔学习率衰减策略 scheduler = optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.95)

④ 周期重启的余弦退火策略.optim.lr_scheduler.CosineAnnealingWarmRestarts

4. 公式:

下面代码 显示上图结果。

通常总训练轮数最少有5~10个周期。重启找到个最优状态,找到个局部最小值。

epoch 总周期长度。比如要5个周期。epoch = 200,T_0 = 40. 周期数 = 200/40

# 批次数 iteration = 10 scheduler = optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.95) # 模型训练 # 遍历批次 for i in range(iteration): # 1.前向传播 y_pred = w * x # 2.计算损失 # 3.梯度清零 # 4.反向传播 # 5.更新参数:w新 = w旧-学习率*梯度 # 6.更新学习率 scheduler.step()

如果: T_mult = 2 .第一个周期点是 50,第二个周期点是150,第三个350.
周期值: 50 100 200

""" optimizer: 梯度下降优化器 T_0: 第一个周期的轮数 eta_min: 最小学习率 最大学习率: 在 optimizer = optim.SGD([w],lr=lr) 里面 上面给的是0.1 T_mult: # 周期倍增因子,默认为1,表示每个周期的轮数相同 """ scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, # 优化器对象 T_0=50, # 第一个周期的轮数 eta_min=0, T_mult=1, )

4. 神经网络的训练流程

1.准备数据集
2.构建神经网络模型
3.设置损失函数和优化器,以及学习率调度器
4.模型训练
1.前向传播
2.计算损失
3.梯度清零
4.反向传播
5.更新参数:w新 = w旧-学习率*梯度
6.更新学习率
5.模型测试

5. 完整代码示例

# 导包 import torch import torch.nn as nn import torch.optim as optim # 优化器模块,提供各种优化器对象,比如SGD,Adam import matplotlib.pyplot as plt # 绘图 # 设置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 微软雅黑 plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 # 定义函数,演示 周期重启的余弦退火策略 def demo(): # 0.初始化参数 lr = 0.1 epochs = 200 iteration = 10 # 1.准备数据集 x = torch.tensor([1.0],dtype=torch.float32) y_true = torch.tensor([0.0],dtype=torch.float32) # 2.构建神经网络模型 # 创建张量,模拟网络参数 w = torch.tensor([1.0],dtype=torch.float32,requires_grad=True) # 3.设置损失函数和优化器,以及学习率调度器 # 损失函数 loss_fn = nn.MSELoss() # 优化器 optimizer = optim.SGD([w],lr=lr) # 学习率调度器 """ optimizer: 梯度下降优化器 step_size:间隔周期 gama: 衰减系数 milestones: 指定间隔 调整点。比如 【50,125,160】 那就是51开始 126开始 161开始 """ #scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=50,gamma=0.5) # 指定间隔学习率衰减策略 #scheduler = optim.lr_scheduler.MultiStepLR(optimizer,milestones=[50,125,160],gamma=0.5) # 指数间隔学习率衰减策略 #scheduler = optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.95) # 周期重启的余弦退火策略 """ optimizer: 梯度下降优化器 T_0: 第一个周期的轮数 eta_min: 最小学习率 最大学习率: 在 optimizer = optim.SGD([w],lr=lr) 里面 上面给的是0.1 T_mult: # 周期倍增因子,默认为1,表示每个周期的轮数相同 """ scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( optimizer, # 优化器对象 T_0=50, # 第一个周期的轮数 eta_min=0, T_mult=1, ) # 4.模型训练 # 定义列表,记录训练轮数和学习率 lr_list = [] epoch_list = [] for epoch in range(epochs): # 0.获取当前轮数 和 学习率,保存到记录列表中 epoch_list.append(epoch) lr_list.append(scheduler.get_last_lr()[0]) # 遍历批次 for i in range(iteration): # 1.前向传播 y_pred = w * x # 2.计算损失 # 3.梯度清零 # 4.反向传播 # 5.更新参数:w新 = w旧-学习率*梯度 # 6.更新学习率 scheduler.step() # 5.可视化学习率变化 plt.plot(epoch_list,lr_list) plt.title("周期重启的余弦退火策略") plt.xlabel("epoch") plt.ylabel("lr") plt.show() # 测试 if __name__ == '__main__': demo()
http://www.jsqmd.com/news/690429/

相关文章:

  • 别再只会按AutoSet了!手把手教你玩转泰克MSO2000B示波器的触发与采样设置
  • ESP32开发板安装终极指南:从零开始快速上手Arduino-ESP32
  • 新手也能一键部署 OpenClaw,这次真的超级简单
  • nli-MiniLM2-L6-H768惊艳效果:小模型在中文法律文本NLI任务上超越BERT-base
  • 2026年3月头部上海景观设计公司推荐,地产景观设计/屋顶花园设计/私家花园设计,上海景观设计施工团队选哪家 - 品牌推荐师
  • COMSOL声学超材料实证研究
  • “谁弄坏的不好说”:什么时候,信任成了被收割的盲目?
  • 【限时技术白皮书】:Docker 27低代码集成性能压测报告(23类低代码引擎+8大PaaS平台横向对比,仅开放72小时)
  • NVIDIA Audio2Face:AI语音驱动面部动画技术解析
  • 财务外包 vs 自建财务:老板该怎么选?
  • 管道疏通技术选型指南 主流服务品牌实测对比 - 优质品牌商家
  • 四川钢材市场螺纹钢(热轧带肋钢筋)现货批发 - 四川盛世钢联营销中心
  • Figma中文插件终极教程:3分钟让英文界面秒变中文,设计师必备效率神器!
  • 告别误触发!用滞回比较器给电源监控电路加个‘防抖’功能(附RC延时设计)
  • 保姆级教程:当Visio弹出激活向导时,如何一步步排查并卸载错误的密钥
  • 大规模图神经网络训练优化:WholeGraph技术实践
  • 【完整源码+数据集+部署教程】苹果品种分割系统源码&数据集分享 [yolov8-seg-C2f-RFCAConv&yolov8-seg-C2f-DCNV3等50+全套改进创新点发刊_一键训练教程_W
  • Hugging Face开源AI生态:从入门到实战指南
  • MySQL 同步到目标库后,怎么确认数据一致?NineData 的同步与比对方案
  • 2026年Q2国内购房移民机构合规服务能力排行 - 优质品牌商家
  • 别盲目卷算法,普通程序员入局大模型正确姿势
  • LNMP架构里,Nginx和PHP-FPM到底是怎么‘谈恋爱’的?一次讲清FastCGI通信原理与调优
  • ChatGPT与BARD:AI对话模型核心技术对比与应用场景
  • 路灯车租赁品牌可靠性实测 6家主流服务商对比解析 - 优质品牌商家
  • 【限时开源】C++26合约成本审计模板(含Bazel规则、Clang插件、Gnuplot性能热力图脚本):仅开放72小时,专供高实时性系统团队
  • Transformer中线性层与激活函数的核心原理与实践
  • 吊顶里的那根龙骨,后来怎么样了
  • OneDrive彻底卸载方案:3分钟清除Windows云存储残留
  • 【dns】:公共DNS
  • 告别串口不够用:手把手教你用WK2124芯片为树莓派/香橙派扩展4个UART