当前位置: 首页 > news >正文

五分钟入门强化学习PPO(Proximal Policy Optimization)

PPO解决了什么痛点


为什么PPO提高了数据的利用率:

总结

  • 传统 PG 不能多次利用:因为它是“死脑筋”,只能吃最新鲜的数据。数据一旦导致了脑子升级,旧数据就立刻和新脑子八字不合,强行用会导致网络崩溃。

  • PPO 能多次利用:因为它自带了一个“数据折算翻译官(重要性采样)”,并且给自己戴上了“绝对不瞎改(Clip 截断)”的锁链。这使得旧数据在一段时间内保鲜,从而让采样效率瞬间翻了近 10 倍!


PPO算法的执行流程:

要彻底弄懂 PPO,最核心的就是要理清它“采样(玩游戏)”“学习(更新参数)”这两步是如何交替进行的。

传统策略梯度算法是“玩一把,学一次,扔掉数据”;而 PPO 的执行流程可以被形象地概括为:“用旧脑子玩一堆游戏存下来,然后关起门来,用这批录像把新脑子反反复复打磨好几遍,最后拿新脑子去顶替旧脑子。”

PPO 算法的完整执行流程分为以下 4 个标准步骤:


代码实现:

import gym # 导入 OpenAI Gym 库,用于构建强化学习的物理或游戏交互环境 import torch # 导入 PyTorch 深度学习框架的核心库 import torch.nn.functional as F # 导入 PyTorch 的函数式接口(包含激活函数 ReLU, 损失函数 MSE 等) import numpy as np # 导入 NumPy 库,用于高效的数组和矩阵运算 import matplotlib.pyplot as plt # 导入 Matplotlib,用于训练后的数据可视化与画图 import rl_utils # 导入自定义的强化学习工具包(包含计算广义优势估计 GAE 等辅助函数) class PolicyNet(torch.nn.Module): ''' 策略网络 (Actor):根据当前状态,输出每个离散动作的概率分布 ''' def __init__(self, state_dim, hidden_dim, action_dim): super(PolicyNet, self).__init__() # 调用父类 nn.Module 的初始化方法 self.fc1 = torch.nn.Linear(state_dim, hidden_dim) # 定义第一层全连接层:输入为状态维度,输出为隐藏层维度 self.fc2 = torch.nn.Linear(hidden_dim, action_dim) # 定义第二层全连接层:输入为隐藏层维度,输出为动作空间维度 def forward(self, x): x = F.relu(self.fc1(x)) # 输入状态 x 经过第一层全连接层后,使用 ReLU 激活函数增加非线性表达能力 # 经过第二层后,使用 softmax 函数在动作维度(dim=1)上进行归一化,输出各个动作的概率(所有动作概率之和为1) return F.softmax(self.fc2(x), dim=1) class ValueNet(torch.nn.Module): ''' 价值网络 (Critic):评估当前状态有多好,输出一个标量状态价值 V(s) ''' def __init__(self, state_dim, hidden_dim): super(ValueNet, self).__init__() # 调用父类 nn.Module 的初始化方法 self.fc1 = torch.nn.Linear(state_dim, hidden_dim) # 定义第一层全连接层:输入为状态维度,输出为隐藏层维度 self.fc2 = torch.nn.Linear(hidden_dim, 1) # 定义第二层全连接层:输出维度固定为 1,因为只需输出一个标量评分 V(s) def forward(self, x): x = F.relu(self.fc1(x)) # 输入状态 x 经过第一层后,使用 ReLU 激活函数 return self.fc2(x) # 直接输出第二层的计算结果,不需要激活函数(因为价值可以是任意实数,正负皆可) class PPO: ''' PPO算法核心 (Proximal Policy Optimization),采用 Clip (截断) 方式限制策略更新幅度 ''' def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device): # 初始化策略网络 (Actor),并将其部署到指定的计算设备(如 CPU 或 GPU) self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device) # 初始化价值网络 (Critic),并将其部署到指定的计算设备 self.critic = ValueNet(state_dim, hidden_dim).to(device) # 为策略网络定义 Adam 优化器,专门负责更新 Actor 的参数,学习率为 actor_lr self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr) # 为价值网络定义 Adam 优化器,专门负责更新 Critic 的参数,学习率为 critic_lr self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr) self.gamma = gamma # 奖励的折扣因子,决定对未来长远奖励的重视程度 self.lmbda = lmbda # GAE (广义优势估计) 的平滑衰减参数,用于平衡方差与偏差 self.epochs = epochs # 魔法机制:PPO 允许同一批采集到的序列数据被重复送入网络训练的轮数 self.eps = eps # PPO 截断范围参数 (通常设为 0.2),控制新策略偏离旧策略的最大安全幅度 self.device = device # 记录当前使用的计算设备 def take_action(self, state): ''' 根据当前状态,通过策略网络 (Actor) 采样选择一个动作 ''' # 将传入的 state (通常是 numpy 数组) 转换为 PyTorch 的 float 张量,增加 batch 维度 [1, state_dim],并送入设备 state = torch.tensor([state], dtype=torch.float).to(self.device) probs = self.actor(state) # 状态进行前向传播,获取当前状态下所有动作的概率分布(例如 [0.8, 0.2]) action_dist = torch.distributions.Categorical(probs) # 根据概率分布构建一个类别分布生成器 action = action_dist.sample() # 按概率“掷骰子”进行采样选出具体执行的动作(概率越大的越容易被选中,保证了探索性) return action.item() # 将 PyTorch 张量格式的动作提取为 Python 标准数值(例如 0 或 1),返回给环境执行 def update(self, transition_dict): ''' 核心训练逻辑:使用收集到的一批数据更新 Actor 和 Critic 的参数 ''' # 将这批字典里的状态数据批量转换为 float 张量并送入设备,形状如 [batch_size, state_dim] states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device) # 将动作数据转换为张量,并通过 view(-1, 1) 强行转换成列向量 [batch_size, 1],方便后续做索引操作 actions = torch.tensor(transition_dict['actions']).view(-1, 1).to( self.device) # 将奖励数据转换为 float 张量的列向量 [batch_size, 1] rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device) # 将下一状态数据批量转换为 float 张量 [batch_size, state_dim] next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device) # 将游戏结束标志转换为 float 张量的列向量 [batch_size, 1](游戏结束为 1,未结束为 0) dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device) # 计算 TD 目标值 (TD Target):眼前真实奖励 + 折扣因子 * Critic对下一步的价值预测。 # 绝妙细节:*(1 - dones) 保证了如果游戏在当前步结束,未来价值期望会被强制清零,仅剩眼前奖励。 td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones) # 计算单步 TD 误差:用拿到手的 TD 目标值 减去 Critic 之前对当前状态瞎猜的价值 td_delta = td_target - self.critic(states) # 利用 rl_utils 提供的工具计算 GAE 优势函数(Advantage),通过 lmbda 综合多步 TD 误差,使打分更平滑、方差更低 advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device) # 获取当时采集数据时,“旧策略 (Behavior Policy)” 选择那些动作的对数概率 # 必须加 .detach() 锁定为纯常数!它作为参考基准,绝不能参与后面的梯度反向传播 old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach() # PPO 魔法开始:让这同一批旧数据,在网络里反复迭代学习 epochs 次!(传统策略梯度这么干会直接崩溃) for _ in range(self.epochs): # 获取不断更新的“新策略”在同样状态下,选择同样动作的对数概率 log_probs = torch.log(self.actor(states).gather(1, actions)) # 计算重要性采样比率 (Ratio) = 新策略概率 / 旧策略概率。 # 数学技巧:e^(ln(A) - ln(B)) = A/B。如果 ratio > 1,说明新策略更喜欢这个动作了。 ratio = torch.exp(log_probs - old_log_probs) # PPO 损失的第一部分:未截断的原始目标 (新旧比率 * 优势) surr1 = ratio * advantage # PPO 损失的第二部分:截断目标。使用 clamp 将比率强行锁死在 [1-eps, 1+eps] (如 0.8~1.2) 之间,然后再乘以优势 surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage # 截断 # 计算 Actor 的最终 Loss:取 surr1 和 surr2 中的较小值(取悲观下界),加负号转化为让 PyTorch 最小化的 Loss。 # 这保证了:如果新策略偏离旧策略太远,收益会被切断,从而保护策略的稳定性不“翻车”。 actor_loss = torch.mean(-torch.min(surr1, surr2)) # PPO损失函数 # 计算 Critic 的 Loss:预测的价值与绝对真理 (TD 目标) 之间的 MSE 均方误差。 # td_target 作为靶子,必须加 .detach() 锁定,防止梯度错误地向回传导。 critic_loss = torch.mean( F.mse_loss(self.critic(states), td_target.detach())) # 在反向传播前,必须清空 Actor 优化器中上一步残留的梯度垃圾 self.actor_optimizer.zero_grad() # 清空 Critic 优化器中上一步残留的梯度垃圾 self.critic_optimizer.zero_grad() # 误差反向传播:自动计算出 Actor 网络里每一个参数权重对 Actor Loss 的梯度(责任分锅) actor_loss.backward() # 误差反向传播:自动计算出 Critic 网络参数对 Critic Loss 的梯度 critic_loss.backward() # 优化器执行刀斧手操作:沿着刚才算出的梯度方向,切实修改 Actor 网络的参数 self.actor_optimizer.step() # 切实修改 Critic 网络的参数 self.critic_optimizer.step()
http://www.jsqmd.com/news/925520/

相关文章:

  • 2026PDF转Word免费方案详细教程:软件网页工具一看就会
  • LeetCode 每日一题笔记 日期:2026.05.31 题目:2126. 摧毁小行星
  • 多张图片转pdf的免费工具推荐?2026图片合并转PDF免费方法汇总 - 科技大爆炸
  • 如何永久备份微信聊天记录:WeChatMsg完整本地化解决方案指南
  • Go 语言反射(Reflection)详解
  • 2026全国制造业AI企业应用十大实战服务商深度评测:为何说“人才孵化”才是AI落地的唯一命门? - 速递信息
  • 2026高精度超声波焊接机:解读行业三大核心趋势 - 速递信息
  • 2026手机照片免费转JPG教程!安卓苹果HEIC转JPG不用软件、在线无水印方法
  • 番茄小说永久保存终极指南:3步构建你的个人数字图书馆
  • Redis 常用操作笔记(Go 开发实战)
  • J-Link/J-Trace调试工具在嵌入式开发中的应用与优化
  • 思源宋体终极指南:5分钟掌握免费开源中文字体完整配置方案
  • 别再用Blender了!用Python这5个库搞定3D建模,从数据处理到打印全流程
  • MD怎么转Word?2026年保姆级教程,3步用小程序秒转
  • 全国十大猎头公司实测排行:核心能力对比解析 - 得赢
  • 长三角淘宝网店运营服务商综合能力排行盘点 - 资讯纵览
  • 苏州卫生间楼顶漏水怎么办?厨房、阳台、外墙漏水本地根治方法+靠谱维修指南 - 吉修匠
  • Winhance中文版:一站式Windows系统优化与配置管理解决方案
  • 终极指南:如何快速破解遗忘的压缩包密码
  • 2026EPS转PDF方法大全!Windows/Mac/在线工具及PS/AI转换教程
  • 别再死记命令了!图解华为交换机MAC地址那些事:老化时间、刷新ARP与端口安全详解
  • Go 语言闭包(Closure)详解
  • 淘宝网店运营服务商排行:知名三家机构实力解析 - 速递信息
  • 2026苏州防水哪家好 本地正规补漏公司口碑排名避坑指南 - 吉修匠
  • 2026年全国制造业AI应用实战服务商优选榜单与采购推荐指南 - 速递信息
  • Python集成测试:验证系统协同工作
  • ESP32显示驱动终极指南:打造高效嵌入式图形界面
  • Go 语言匿名函数详解
  • PPT怎么转PDF?2026年手把手教你(小程序/PowerPoint/WPS/在线工具完整方案)
  • 终极炉石传说插件:HsMod完整功能指南与安装教程