当前位置: 首页 > news >正文

用PyTorch手把手实现DDPG算法,搞定OpenAI Gym连续控制任务(附完整代码)

用PyTorch手把手实现DDPG算法,搞定OpenAI Gym连续控制任务

深度确定性策略梯度(DDPG)作为强化学习领域的重要算法,在机器人控制、自动驾驶等连续动作空间场景中展现出独特优势。本文将带您从零开始构建完整的DDPG实现,通过PyTorch框架解决OpenAI Gym中的经典控制问题Pendulum-v0。不同于理论讲解,我们聚焦工程实践中的关键细节,提供可直接运行的代码方案。

1. 环境配置与核心架构

在开始编码前,需要配置基础环境并理解DDPG的双网络架构。Pendulum-v0环境模拟倒立摆控制任务,其状态空间包含摆角的正余弦值和角速度,动作空间为连续扭矩值。

import gym import torch import numpy as np env = gym.make('Pendulum-v0') state_dim = env.observation_space.shape[0] # 状态维度:3 action_dim = env.action_space.shape[0] # 动作维度:1 action_bound = env.action_space.high[0] # 动作范围:[-2.0, 2.0]

DDPG采用Actor-Critic架构,包含四个神经网络:

  • 在线Actor:策略网络,输入状态输出确定性动作
  • 目标Actor:稳定训练的策略网络副本
  • 在线Critic:价值网络,评估状态-动作对的Q值
  • 目标Critic:稳定训练的价值网络副本
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class Actor(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim=64): super().__init__() self.fc1 = nn.Linear(state_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.fc3 = nn.Linear(hidden_dim, action_dim) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) return torch.tanh(self.fc3(x)) * action_bound

2. 经验回放与噪声探索

DDPG通过经验回放机制打破数据相关性,使用OU噪声实现有效探索。我们实现一个高效的回放缓冲区:

class ReplayBuffer: def __init__(self, capacity): self.buffer = collections.deque(maxlen=capacity) def add(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): transitions = random.sample(self.buffer, batch_size) return zip(*transitions)

对于连续动作空间的探索,采用Ornstein-Uhlenbeck过程噪声:

class OUNoise: def __init__(self, action_dim, mu=0, theta=0.15, sigma=0.2): self.action_dim = action_dim self.mu = mu self.theta = theta self.sigma = sigma self.reset() def reset(self): self.state = np.ones(self.action_dim) * self.mu def sample(self): dx = self.theta * (self.mu - self.state) dx += self.sigma * np.random.randn(self.action_dim) self.state += dx return self.state

3. 网络训练与软更新机制

DDPG的核心训练流程包含Critic的TD误差最小化和Actor的策略梯度上升:

def update(self, batch): states, actions, rewards, next_states, dones = batch # Critic损失计算 next_actions = self.target_actor(next_states) target_q = self.target_critic(next_states, next_actions) target_q = rewards + (1 - dones) * self.gamma * target_q current_q = self.critic(states, actions) critic_loss = F.mse_loss(current_q, target_q.detach()) # Actor策略优化 actor_loss = -self.critic(states, self.actor(states)).mean() # 网络参数更新 self.critic_optimizer.zero_grad() critic_loss.backward() self.critic_optimizer.step() self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step() # 目标网络软更新 self.soft_update(self.actor, self.target_actor) self.soft_update(self.critic, self.target_critic)

软更新通过参数混合实现稳定训练:

def soft_update(self, local_model, target_model): for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): target_param.data.copy_(self.tau*local_param.data + (1.0-self.tau)*target_param.data)

4. 完整训练流程与性能优化

将各模块整合为完整训练流程,关键参数设置如下:

参数推荐值作用
buffer_size100000经验回放容量
batch_size64训练批大小
gamma0.99折扣因子
tau0.005软更新系数
actor_lr1e-4Actor学习率
critic_lr1e-3Critic学习率

训练循环实现:

def train_agent(env, agent, episodes=1000): returns = [] for episode in range(episodes): state = env.reset() episode_return = 0 noise.reset() while True: action = agent.select_action(state) next_state, reward, done, _ = env.step(action) agent.replay_buffer.add(state, action, reward, next_state, done) if len(agent.replay_buffer) > batch_size: agent.update() state = next_state episode_return += reward if done: break returns.append(episode_return) print(f"Episode {episode}: Return {episode_return:.1f}") return returns

实际训练中常见问题与解决方案:

  1. 训练不稳定

    • 增大回放缓冲区容量
    • 降低学习率
    • 增加目标网络更新频率
  2. 探索不足

    • 调整OU噪声参数
    • 初期采用更大噪声幅度
    • 逐步衰减噪声强度
  3. 收敛速度慢

    • 优化网络结构(增加层宽/深度)
    • 尝试不同的激活函数
    • 调整批归一化策略

5. 实战效果分析与调优建议

在Pendulum-v0环境中,典型的训练曲线呈现三个阶段:

  1. 探索期(0-200回合):回报波动大,智能体随机尝试不同动作
  2. 学习期(200-600回合):回报快速上升,策略明显改善
  3. 稳定期(600+回合):回报趋于稳定,策略接近最优

通过修改网络结构和训练参数可进一步提升性能:

# 更深的网络结构 class DeepCritic(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim=256): super().__init__() self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.fc3 = nn.Linear(hidden_dim, hidden_dim) self.fc4 = nn.Linear(hidden_dim, 1) def forward(self, x, a): x = torch.cat([x, a], dim=1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = F.relu(self.fc3(x)) return self.fc4(x)

最终实现的DDPG算法能够在100-200个训练回合内稳定倒立摆,平均回报达到-200以下(原始环境定义倒立垂直向上为0,向下为-1600)。相比离散动作空间的DQN,DDPG在连续控制任务中展现出三大优势:

  1. 动作精度高:可输出连续扭矩值
  2. 训练效率高:不需要离散化动作空间
  3. 策略更平滑:确定性策略避免动作抖动

实际部署时,建议保存训练好的模型参数:

torch.save({ 'actor': actor.state_dict(), 'critic': critic.state_dict(), }, 'ddpg_model.pth')
http://www.jsqmd.com/news/986762/

相关文章:

  • 动手算一算:如何用Python快速估算光纤通信系统的最大传输距离?
  • 安徽2026年中考无缘高中,还有什么办法上大学? - 小张zc
  • 盐城矮脚拿破仑,金吉拉哪家店比较好,2026精选宠物店排行榜推荐 - 谊识预商务
  • Vue2响应式原理详解——简单易理解
  • 中兴交换机堆叠配置保姆级教程:从端口关闭到重启上线的完整流程
  • Placement-Preparation求职全攻略:从简历准备到面试技巧的完整指南
  • STM32CubeMX配置SPI驱动W25Q64,从零到读写测试的保姆级避坑指南
  • 开源大模型2024生产选型实战:推理效率、硬件适配与中文落地
  • 2026液冷系统排液阀源头工厂推荐:液冷管截止阀全品类生产厂家实力解析 - 栗子测评
  • 盐城边牧,法斗,德牧哪家店比较好,2026精选宠物店排行榜推荐 - 谊识预商务
  • 用MATLAB复现四通道麦克风阵列TDOA定位:从数据集构建到双曲线交汇算法实战
  • AI 推广公司哪家好?2026 实测对比 - 新闻快传
  • `javax.xml.validation` 是 Java 标准版(Java SE)中用于 XML 文档验证的核心包
  • 2026年郑州短视频代运营与GEO优化推广服务商深度横评指南 - 企业名录优选推荐
  • 保姆级教程:用STM32F103驱动ST7735屏幕显示高清图片(附Python图片转换脚本)
  • 保姆级教程:用NVIDIA SDK Manager给Jetson Xavier NX刷机,附99%卡住、SSD启动失败等常见问题解决
  • 什么牌子素颜霜最好用?盘点2026好用又自然的素颜霜口碑榜 - 新闻快传
  • MySQL5.7免安装教程
  • 告别虚拟机!用Docker在Mac/Windows上5分钟搞定Oracle 19c开发环境
  • 多项式插值原理与工程实践:从穿点拟合到龙格现象规避
  • REFramework兼容性问题深度解析:5步解决《怪物猎人:荒野》崩溃难题
  • 2026 年 6 月武汉黄金回收|添价收黄金奢侈品回收中心,专业估价诚意出价 - 薛定谔的梨花猫
  • 别再只调参了!深入SENet消融实验,揭秘通道注意力超参数(如压缩比r)的实战影响
  • 从Sort到DeepSORT:我是如何用‘外观特征’解决目标跟踪中ID频繁跳变这个老大难问题的
  • 音乐歌词获取利器:一键解决你的歌词烦恼,高效管理音乐库
  • 告别玄学调参:用ADS负载/源牵引一步步优化你的2400MHz功放效率(附完整Harmonic Balance设置)
  • 告别2003错误:在CentOS 7上为Navicat配置MySQL远程访问的完整指南
  • `javax.xml.rpc.holders` 是 JAX-RPC(Java API for XML-Based RPC)规范中的一个包
  • 构建企业级语音识别系统:Whisper Base英文模型深度解析与实践指南
  • BlazorFluentUI核心组件解析:打造Windows 11风格的Blazor应用