当前位置: 首页 > news >正文

Off-Policy Actor-Critic 与重要性采样

Off-Policy Actor-Critic 与重要性采样

目录:

  1. Importance sampling
  2. The off policy gradient theorem
  3. Algorithm description
  4. python 例子

一 Importance sampling

在 off-policy Actor-Critic 方法中,我们通常用一个行为策略(behavior policy)对应分布来生成数据,去优化另一个目标策略(target policy)对应分布。这时,期望的估计就需要用到重要性采样。

考虑一个简单的随机变量

假设在概率分布下:

则:


另一个分布下:

则:

如果我们只有来自的样本,却想估计,可以通过重要性采样:


即:



这里的就是重要性权重

- 如果,权重为 1,无需调整。
- 如果,说明在目标分布中该样本更重要,应升高权重。
- 如果,则降低权重。

在 off-policy Actor-Critic 中,策略梯度会乘以类似的重要性权重,以校正行为策略与目标策略之间的分布差异。


二 The off -policy policy gradient theorem

借助重要性采样技术,我们现在可以给出 off-policy 策略梯度定理。假设 β 是一个行为策略(behavior policy)。我们的目标是利用 β生成的样本来学习一个目标策略 π,使其最大化以下指标

其中 dβdβ 是策略 ββ 下的平稳分布,vπvπ 是策略 ππ 下的状态价值函数。该指标的梯度由以下定理给出。

定理 10.1(Off-policy 策略梯度定理)
在折扣因子 γ∈(0,1) 的情况下,J(θ) 的梯度为:

其中状态分布 ρ定义为:

这里

表示从状态 s′ 出发,在策略 π下经过折扣后转移到状态 s 的总概率。

(10.11) 中的梯度与同策略情形下的定理 9.1 类似,但存在两点区别。第一点区别在于引入了重要性权重。第二点区别在于动作的采样分布是 A∼β,而不是 A∼π。因此,我们可以通过跟随行为策略 ββ 生成的动作样本,来近似真实的梯度。该定理的证明见专栏 10.2。


三 Algorithm description

基于 off-policy 策略梯度定理,我们现在可以给出 off-policy actor-critic 算法。由于 off-policy 情形与 on-policy 情形非常相似,我们仅介绍其中的关键步骤。

首先,off-policy 策略梯度对任何额外的基线 (baseline) (具有不变性。具体而言,我们有:



这是因为:



为了减小估计方差,我们可以选择基线为,从而得到:



对应的随机梯度上升算法为:



其中

与 on-policy 情形类似,优势函数 \( q_t(s, a) - v_t(s) \) 可以用 TD 误差替代,即:



于是算法变为:



off-policy actor-critic 算法的实现总结在算法 10.3 中。可以看出,该算法与优势 actor-critic 算法基本相同,唯一的区别在于,在 critic 和 actor 中都额外包含了一个重要性权重。

需要注意的是,除了 actor 之外,critic 也通过重要性采样技术从 on-policy 转换为 off-policy。事实上,重要性采样是一种通用技术,可以应用于基于策略和基于价值的算法。

最后,算法 10.3 可以通过多种方式进行扩展,以融入更多技术,例如资格迹 (eligibility traces) [73]。


四 代码例子

上面算法10.3是教科书中的经典但不实用,实际上很难训练

主要原有

  • 重要性权重方差大,易爆炸

  • 双网络耦合,更新不稳定

  • 行为策略需全覆盖目标策略,条件苛刻

# -*- coding: utf-8 -*- """ 文件名: off_policy_ac.py Off-policy Actor-Critic with Importance Sampling (Algorithm 10.3) 行为策略: ε-greedy (if-else 实现) 环境: CartPole-v1 作者: chengxf2 日期: 2026-06-08 """ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import gymnasium as gym import numpy as np import matplotlib.pyplot as plt # ============================= # 1. 网络结构:Actor 与 Critic # ============================= class Actor(nn.Module): """策略网络 π(a|s,θ)""" def __init__(self, state_dim: int, action_dim: int): super(Actor, self).__init__() self.net = nn.Sequential( nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, action_dim) ) self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) nn.init.constant_(m.bias, 0.0) def forward(self, state: torch.Tensor): logits = self.net(state) prob = F.softmax(logits, dim=-1) return prob class Critic(nn.Module): """价值网络 v(s,w)""" def __init__(self, state_dim: int): super(Critic, self).__init__() self.net = nn.Sequential( nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 1) ) self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) nn.init.constant_(m.bias, 0.0) def forward(self, state: torch.Tensor): value = self.net(state) return value # ============================= # 2. 智能体 (Agent) # ============================= class Agent: def __init__(self, state_dim: int, action_dim: int, gamma: float = 0.98): self.actor = Actor(state_dim, action_dim) self.critic = Critic(state_dim) self.actor_optim = optim.Adam(self.actor.parameters(), lr=3e-4) self.critic_optim = optim.Adam(self.critic.parameters(), lr=5e-2) self.gamma = gamma self.action_dim = action_dim def select_action_via_behavior_policy(self, state, epsilon: float = 0.2): """ε-greedy 行为策略,返回 (动作, 该动作在β下的概率)""" state_ts = torch.FloatTensor(state).unsqueeze(0) with torch.no_grad(): pi_probs = self.actor(state_ts).squeeze(0) greedy_action = torch.argmax(pi_probs).item() if np.random.rand() < epsilon: action = np.random.randint(self.action_dim) else: action = greedy_action # 计算 β(a) if action == greedy_action: beta_prob = 1.0 - epsilon + epsilon / self.action_dim else: beta_prob = epsilon / self.action_dim return action, beta_prob def learn(self, state, action, reward, next_state, done, beta_prob): """ 单步更新 (Algorithm 10.3) 返回: (is_ratio, td_error) 用于监控 """ s = torch.FloatTensor(state).unsqueeze(0) next_s = torch.FloatTensor(next_state).unsqueeze(0) r = torch.FloatTensor([reward]) # ----- TD 目标与误差 ----- v_curr = self.critic(s) with torch.no_grad(): v_next = self.critic(next_s) if not done else torch.tensor([[reward]]) td_target = r + self.gamma * v_next td_error = td_target.detach() - v_curr # δ_t #td_error = torch.clamp(td_error, -10.0, 10.0) # ----- 重要性权重 ρ = π(a|s) / β(a|s) ----- pi_probs = self.actor(s).squeeze(0) pi_prob = pi_probs[action] #print(pi_prob, beta_prob) is_ratio = (pi_prob / (beta_prob + 1e-6)).detach() is_ratio = torch.clamp(is_ratio, 0.1, 10.0) # 截断防爆炸 # ----- 1. 更新 Critic (梯度上升) ----- critic_loss = 0.5 * is_ratio * td_error.pow(2) self.critic_optim.zero_grad() critic_loss.backward() #torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0) self.critic_optim.step() # ----- 2. 更新 Actor (梯度上升) ----- log_prob = torch.log(pi_prob + 1e-6) actor_loss = -is_ratio * td_error.detach() * log_prob self.actor_optim.zero_grad() #torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=1.0) actor_loss.backward() self.actor_optim.step() #print(pi_probs) # 返回标量用于打印(直接取出数值) return is_ratio.item(), td_error.item() # ============================= # 3. 训练主循环 # ============================= def train(episodes: int = 500, render: bool = False): env = gym.make('CartPole-v1', render_mode='human' if render else None) state_dim = env.observation_space.shape[0] action_dim = env.action_space.n agent = Agent(state_dim, action_dim) episode_rewards = [] for ep in range(1, episodes + 1): state, _ = env.reset() episode_reward = 0 done = False while not done: action, beta_prob = agent.select_action_via_behavior_policy(state, epsilon=0.2) next_state, reward, terminated, truncated, _ = env.step(action) done = terminated or truncated # 可选的奖励塑造(加速学习,注释掉即为标准环境) is_ratio, td_error = agent.learn(state, action, reward, next_state, done, beta_prob) state = next_state episode_reward += 1 episode_rewards.append(episode_reward) if ep % 10 == 0: avg_reward = np.mean(episode_rewards[-10:]) # 注意:这里 is_ratio 和 td_error 是最后一次 step 的值,只用于观察 print(f"Episode {ep:3d} | Avg Reward (last 10): {avg_reward:.1f} " f"| is_ratio = {is_ratio:.2f}, td_error = {td_error:.2f} ") env.close() print("Training finished.") # 绘制奖励曲线 plt.figure(figsize=(10, 5)) plt.plot(episode_rewards, alpha=0.6, label='Episode Reward') if len(episode_rewards) >= 10: moving_avg = np.convolve(episode_rewards, np.ones(10)/10, mode='valid') plt.plot(range(9, len(episode_rewards)), moving_avg, 'r', linewidth=2, label='Moving Avg (10)') plt.xlabel('Episode') plt.ylabel('Total Reward') plt.title('Off-policy Actor-Critic (ε-greedy behavior) on CartPole-v1') plt.legend() plt.grid(True) plt.savefig('training_curve.png', dpi=150) plt.show() if __name__ == '__main__': train(episodes=300, render=False)
http://www.jsqmd.com/news/980911/

相关文章:

  • Python开发工程师全景解析:岗位职责·各城市薪资·发展前景·高考志愿填报(2026版)
  • 2026如何提升营销岗位的职场能力和核心竞争力
  • 99个免费公共Tracker终极指南:让BT下载速度飙升300%的完整方案
  • Bili23 Downloader 技术解析:B站流媒体架构与API交互机制研究
  • 2024 LLM开发实操指南:本地化部署与RAG微调全链路
  • 黄冈美度天梭+宝玑手表专业回收,26年精选回收店铺排行榜推荐 - 莘州文化
  • LLM代理层消亡史:当模型原生能力让网关退化为透传器
  • 如何在3分钟内为Microsoft Word添加APA第7版参考文献格式?
  • 激活 Change Pointers,让 SAP HR OM 模型只分发变化而不是重发整棵组织树
  • 吉安法穆兰+卡地亚手表专业回收,26年精选回收店铺排行榜推荐 - 莘州文化
  • 计算机毕业设计之django基于python网络安全攻防学习平台
  • 嘉定区配镜深度调研:行业洗牌下,本土品牌如何突围?—— 以嘉艺眼镜公场为例 - 国麟测评
  • 双喜临门|腾视科技杭州总部及深圳子公司乔迁新址,以全新姿态奔赴新征程!
  • 深度解析 Deep-Live-Cam:从原理到实战的 AI 换脸技术指南
  • douyin-downloader:如何通过三层架构设计实现抖音内容的高效批量采集
  • 高校信息安全课用的Python版CA证书系统(带源码+部署指南+全流程截图)
  • 从拍照到识别:一条龙搞定K210物体检测项目(Mx-yolov3 + 自动拍照脚本 + 脱机部署)
  • 终极免费指南:如何用Wand-Enhancer解锁WeMod完整专业功能
  • 别再让雷劈了你的设备!手把手教你为RS485接口选配TVS、GDT和TBU(附IEC标准解读)
  • 5分钟掌握KH Coder:零编程文本挖掘与数据分析的终极指南
  • LLM技术雷达:推理优化、长上下文与评估可信度实战指南
  • 重大升级|大家反映配置最复杂的“会务报名”也变成“点哪儿改哪儿”啦!
  • Ansys仿真许可优化六步法,两家工具自动化程度
  • 83-Java 自动装箱和拆箱
  • Steam成就管理终极教程:如何快速解锁、重置和管理你的Steam成就
  • 莲湖区家政公司选型:防水补漏、通马桶与保姆月嫂护工参考 - 资讯速览
  • Applite:如何让Mac软件管理变得像App Store一样简单?
  • MATLAB实现TDOA+AOA混合定位仿真:含坐标转换、三角解算与误差分析
  • 如何快速掌握Calibre豆瓣元数据插件:面向电子书爱好者的完整解决方案
  • 31851个成语结构化数据集:带拼音、释义、古籍出处和现代例句,支持Excel/文本/数据库直接导入