当前位置: 首页 > news >正文

DDPG算法里的‘演员’和‘评论家’到底在吵什么?用Python代码逐行拆解训练过程

DDPG算法里的‘演员’和‘评论家’到底在吵什么?用Python代码逐行拆解训练过程

想象一下,你正在导演一场没有剧本的即兴戏剧。演员(Actor)需要在舞台上即兴发挥,而评论家(Critic)则在台下实时点评。这场戏的特殊之处在于——演员的动作可以精确到毫米级的角度变化,而评论家的打分标准也在不断调整。这就是DDPG(深度确定性策略梯度)算法的核心戏剧冲突。让我们用PyTorch代码作为舞台,揭开这场"表演艺术"背后的技术内幕。

1. 搭建舞台:DDPG的四大角色初始化

任何好戏都需要精心搭建舞台。在DDPG的宇宙里,我们需要先准备好四个关键神经网络:

import torch import torch.nn as nn import torch.optim as optim import numpy as np class Actor(nn.Module): def __init__(self, state_dim, action_dim, max_action): super(Actor, self).__init__() self.layer_1 = nn.Linear(state_dim, 400) self.layer_2 = nn.Linear(400, 300) self.layer_3 = nn.Linear(300, action_dim) self.max_action = max_action def forward(self, state): x = torch.relu(self.layer_1(state)) x = torch.relu(self.layer_2(x)) return self.max_action * torch.tanh(self.layer_3(x)) class Critic(nn.Module): def __init__(self, state_dim, action_dim): super(Critic, self).__init__() self.layer_1 = nn.Linear(state_dim + action_dim, 400) self.layer_2 = nn.Linear(400, 300) self.layer_3 = nn.Linear(300, 1) def forward(self, state, action): x = torch.cat([state, action], dim=1) x = torch.relu(self.layer_1(x)) x = torch.relu(self.layer_2(x)) return self.layer_3(x)

这里有两个关键设计决策值得注意:

  • Actor的输出层使用tanh:将动作限制在[-max_action, max_action]范围内
  • Critic接收状态和动作的拼接:这是Q函数的典型设计,用于评估(state, action)对的价值

四个角色的初始化就像组建剧团:

# 主演员和主评论家 actor = Actor(state_dim, action_dim, max_action) critic = Critic(state_dim, action_dim) # 备用演员和备用评论家(目标网络) target_actor = Actor(state_dim, action_dim, max_action) target_critic = Critic(state_dim, action_dim) # 初始时目标网络与主网络参数相同 target_actor.load_state_dict(actor.state_dict()) target_critic.load_state_dict(critic.state_dict())

2. 排练过程:训练循环中的动态博弈

真正的戏剧性冲突发生在训练循环中。让我们分解一个完整的训练步骤:

2.1 经验收集阶段

def select_action(state, noise): state = torch.FloatTensor(state.reshape(1, -1)) action = actor(state).data.numpy().flatten() return np.clip(action + noise, -max_action, max_action) # 在环境中执行动作并存储经验 next_state, reward, done, _ = env.step(action) replay_buffer.add(state, action, reward, next_state, done)

这里引入的探索噪声就像演员的即兴发挥——在确定性策略中加入随机性,避免表演变得刻板。常见的选择是Ornstein-Uhlenbeck噪声,它能产生时间相关的随机过程,适合物理系统的连续控制。

2.2 批评家的学习时刻

从经验池采样后,Critic开始它的"毒舌点评":

# 计算目标Q值 target_actions = target_actor(next_states) target_q_values = target_critic(next_states, target_actions) targets = rewards + (1 - dones) * gamma * target_q_values # 计算当前Q值估计 current_q_values = critic(states, actions) # Critic损失函数 critic_loss = nn.MSELoss()(current_q_values, targets.detach())

Critic的更新包含三个关键点:

  1. 使用目标网络计算target_q_values保持稳定性
  2. targets.detach()切断梯度回传,防止干扰目标网络
  3. (1 - dones)项处理回合终止时的特殊情况

2.3 演员的自我修养

Actor的更新则更有意思——它试图讨好Critic:

actor_loss = -critic(states, actor(states)).mean()

这个简单的表达式蕴含着深度策略梯度:

  • 通过Critic评估Actor当前策略的表现
  • 负号表示我们要最大化这个评估值
  • 梯度上升转化为损失函数的极小化

2.4 温和的更新:软同步机制

DDPG最精妙的设计在于目标网络的更新方式:

def soft_update(target, source, tau): for target_param, param in zip(target.parameters(), source.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) # 更新目标网络 soft_update(target_actor, actor, tau) soft_update(target_critic, critic, tau)

这种Polyak平均策略(tau通常取0.005)就像老演员缓慢吸收新演员的表演风格,避免突然的风格转变吓到观众。

3. 幕后花絮:关键技巧与调试经验

在实际制作中,有几个幕后技巧决定了演出成败:

3.1 经验回放的秘密配方

class ReplayBuffer: def __init__(self, max_size): self.buffer = [] self.max_size = max_size def add(self, state, action, reward, next_state, done): if len(self.buffer) >= self.max_size: self.buffer.pop(0) self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): indices = np.random.choice(len(self.buffer), batch_size) states, actions, rewards, next_states, dones = zip(*[self.buffer[i] for i in indices]) return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)

经验回放的两个关键参数:

  • buffer大小:通常1e5到1e6,太小导致样本相关性高,太大则学习缓慢
  • batch大小:一般从128开始尝试,复杂任务可能需要更大batch

3.2 学习率的舞蹈

Actor和Critic通常需要不同的学习节奏:

actor_optimizer = optim.Adam(actor.parameters(), lr=1e-4) critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

典型配置:

  • Critic学习率是Actor的5-10倍
  • 太高的Actor学习率会导致策略震荡
  • 太低的Critic学习率则使反馈信号滞后

3.3 噪声退火策略

聪明的导演会随着排练进度减少即兴发挥:

def update_noise(noise_scale): noise_scale *= 0.9999 # 指数衰减 return max(noise_scale, 0.1) # 保持最小探索

这种退火策略平衡了:

  • 初期:高噪声促进探索
  • 后期:低噪声利于策略精修

4. 完整演出:Pendulum-v1实例解析

让我们看一个钟摆平衡的具体案例。以下是训练循环的核心代码:

for episode in range(total_episodes): state = env.reset() episode_reward = 0 noise_scale = initial_noise for step in range(max_steps): action = select_action(state, noise_scale * np.random.randn(action_dim)) next_state, reward, done, _ = env.step(action) replay_buffer.add(state, action, reward, next_state, done) state = next_state episode_reward += reward if len(replay_buffer) > batch_size: states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size) # 转换为PyTorch张量 states = torch.FloatTensor(states) actions = torch.FloatTensor(actions) rewards = torch.FloatTensor(rewards).unsqueeze(1) next_states = torch.FloatTensor(next_states) dones = torch.FloatTensor(dones).unsqueeze(1) # Critic更新 critic_optimizer.zero_grad() critic_loss = compute_critic_loss(states, actions, rewards, next_states, dones) critic_loss.backward() critic_optimizer.step() # Actor更新 actor_optimizer.zero_grad() actor_loss = compute_actor_loss(states) actor_loss.backward() actor_optimizer.step() # 软更新目标网络 soft_update(target_actor, actor, tau) soft_update(target_critic, critic, tau) noise_scale = update_noise(noise_scale) print(f"Episode {episode}, Reward: {episode_reward}")

训练过程中常见的现象记录:

训练阶段典型现象解决方案
初期 (0-1k步)奖励随机波动增加噪声规模,增大回放缓冲区
中期 (1k-10k步)偶尔出现高分但不稳定检查Critic损失是否收敛,调整学习率
后期 (>10k步)性能平台期尝试减小噪声,微调网络结构

在Pendulum-v1环境中,成功的训练通常会在约50-100个episode后开始出现稳定的摆动策略,300个episode左右能达到接近最优的性能。

http://www.jsqmd.com/news/918233/

相关文章:

  • 1379份真实中文临床文本,含手术/药物/疾病等六类实体的字符级标注数据
  • 网盘直链下载助手:技术深度解析与实战指南
  • 番茄小说下载器:三步实现离线阅读自由的专业方案
  • Windows 11优化终极指南:5步让你的电脑重获新生
  • 实体门店短视频获客工具前十|选对工具,门店少亏三年冤枉钱!
  • 业绩翻两番:免漆木门经销商的增长秘诀 - 资讯纵览
  • Ubuntu局域网传文件,除了SCP你还可以试试这个:Rsync增量备份实战
  • 终极解决方案:3分钟让魔兽争霸3在现代电脑上完美运行 [特殊字符]
  • 用Python玩转赌徒问题:手把手教你实现MDP的两种经典算法(附完整代码)
  • 5步解决虚拟机手柄识别难题:DS4Windows虚拟机配置终极指南
  • 基于ESP32的四足机器人:从逆运动学到AI视觉的完整实现
  • 告别ImageNet标注!用DINO+ViT在无标签数据上实现80%+准确率的保姆级复现教程
  • 2026芜湖奢侈品名牌包包名牌手表回收哪家无套路? - 鸿运名品
  • #三清侠# 最近发现一个超有安全感的“新侠客”[特殊字符]
  • Go语言微服务安全与可靠性最佳实践
  • SQLite Viewer终极指南:如何在浏览器中零安装查看和管理SQLite数据库
  • DWG 格式兼容转换的实战应用与价值落地
  • 电力系统潮流计算Python工程包,含VS解决方案与完整源码
  • YOLO训练翻车?可能是你的TXT标注文件‘回炉’没做好!手把手教你TXT转回Labelme JSON
  • 破解免漆木门行业痛点:四稳共赢方法论如何打造高口碑产品? - 资讯纵览
  • 大语言模型如何“认识”你:从原理到个人数字身份监控实践
  • 3DS自制软件管理终极指南:Universal-Updater一键安装与更新完整教程
  • ABB 011865-003 3/8NPT 内外丝 90° 黄铜弯头
  • 【硬件_USB2.0】一文讲透USB2.0硬件工作原理
  • 5/17(3)
  • 基于Arduino的RC遥控车与激光计时系统DIY全攻略
  • 5/16(2)
  • 别再为CFD-POST云图毛刺抓狂了!手把手教你排查Fluent后处理显示异常(附完整流程)
  • 如何5分钟完成Honey Select 2终极汉化去码补丁安装:完整新手指南
  • 2026 中央电教馆美术教育指导教师证书详解|职业前景、报考流程、官方报名渠道推荐、证书含金量等问题一站式解答 - 教育官方推荐官