当前位置: 首页 > news >正文

DQN 算法实战:CartPole-v0 环境 1000 轮训练实现 200 分满分

DQN算法实战:从零构建CartPole智能体的完整指南

1. 环境准备与基础概念

在开始构建DQN智能体之前,我们需要先理解几个核心概念。CartPole-v0是OpenAI Gym中的一个经典控制问题,目标是让小车上的杆子保持直立不倒下。这个环境有四个状态变量:小车位置、小车速度、杆子角度和杆子角速度;两个动作:向左或向右施加力。

首先安装必要的Python库:

pip install gym numpy torch matplotlib

DQN(Deep Q-Network)结合了深度学习和Q-learning,通过神经网络来近似Q函数。与传统Q-learning使用表格存储Q值不同,DQN可以处理高维状态空间。以下是DQN的三大核心组件:

  1. 经验回放(Experience Replay):存储并随机采样过去的经验,打破数据间的相关性
  2. 目标网络(Target Network):稳定训练过程的第二个网络
  3. 神经网络近似:用深度神经网络代替Q表

2. DQN实现详解

2.1 神经网络结构设计

我们使用PyTorch构建一个简单的三层全连接网络:

import torch import torch.nn as nn import torch.optim as optim class DQN(nn.Module): def __init__(self, state_size, action_size): super(DQN, self).__init__() self.fc1 = nn.Linear(state_size, 64) self.fc2 = nn.Linear(64, 64) self.fc3 = nn.Linear(64, action_size) def forward(self, x): x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) return self.fc3(x)

这个网络接收4维状态向量,输出2个动作的Q值。隐藏层使用ReLU激活函数,最后一层直接输出Q值。

2.2 经验回放实现

经验回放是DQN稳定训练的关键,它存储了智能体与环境交互的经验(状态、动作、奖励、新状态、是否终止):

from collections import deque import random class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size) def __len__(self): return len(self.buffer)

经验回放有两大优势:

  1. 打破数据间的时序相关性
  2. 提高数据利用率,每个经验可以被多次使用

2.3 训练流程代码实现

完整的训练流程包括环境交互、经验存储、网络更新等步骤:

import gym import numpy as np env = gym.make('CartPole-v0') state_size = env.observation_space.shape[0] action_size = env.action_space.n # 超参数设置 BATCH_SIZE = 64 GAMMA = 0.99 EPS_START = 1.0 EPS_END = 0.01 EPS_DECAY = 0.995 TARGET_UPDATE = 10 MEMORY_CAPACITY = 10000 policy_net = DQN(state_size, action_size) target_net = DQN(state_size, action_size) target_net.load_state_dict(policy_net.state_dict()) optimizer = optim.Adam(policy_net.parameters()) memory = ReplayBuffer(MEMORY_CAPACITY) def select_action(state, eps): if random.random() < eps: return random.randint(0, action_size-1) with torch.no_grad(): return policy_net(state).argmax().item() def optimize_model(): if len(memory) < BATCH_SIZE: return batch = memory.sample(BATCH_SIZE) state_batch = torch.cat([s for (s,a,r,ns,d) in batch]) action_batch = torch.tensor([a for (s,a,r,ns,d) in batch]) reward_batch = torch.tensor([r for (s,a,r,ns,d) in batch]) next_state_batch = torch.cat([ns for (s,a,r,ns,d) in batch]) done_batch = torch.tensor([d for (s,a,r,ns,d) in batch]) current_q = policy_net(state_batch).gather(1, action_batch.unsqueeze(1)) next_q = target_net(next_state_batch).max(1)[0].detach() expected_q = reward_batch + (GAMMA * next_q * (1 - done_batch)) loss = nn.MSELoss()(current_q.squeeze(), expected_q) optimizer.zero_grad() loss.backward() optimizer.step()

3. 高级调优技巧

3.1 双DQN(Double DQN)

原始DQN存在Q值高估问题,双DQN通过解耦动作选择和Q值评估来缓解:

# 修改optimize_model函数中的next_q计算 next_actions = policy_net(next_state_batch).max(1)[1].unsqueeze(1) next_q = target_net(next_state_batch).gather(1, next_actions).squeeze(1).detach()

双DQN相比原始DQN有两个优势:

  1. 减少Q值高估
  2. 提高策略稳定性

3.2 优先经验回放(Prioritized Experience Replay)

不是均匀采样经验,而是根据TD误差大小赋予不同优先级:

class PrioritizedReplayBuffer: def __init__(self, capacity, alpha=0.6): self.alpha = alpha self.buffer = [] self.priorities = np.zeros((capacity,), dtype=np.float32) self.pos = 0 self.capacity = capacity def push(self, state, action, reward, next_state, done): max_prio = self.priorities.max() if self.buffer else 1.0 if len(self.buffer) < self.capacity: self.buffer.append((state, action, reward, next_state, done)) else: self.buffer[self.pos] = (state, action, reward, next_state, done) self.priorities[self.pos] = max_prio self.pos = (self.pos + 1) % self.capacity def sample(self, batch_size, beta=0.4): if len(self.buffer) == self.capacity: prios = self.priorities else: prios = self.priorities[:self.pos] probs = prios ** self.alpha probs /= probs.sum() indices = np.random.choice(len(self.buffer), batch_size, p=probs) samples = [self.buffer[idx] for idx in indices] total = len(self.buffer) weights = (total * probs[indices]) ** (-beta) weights /= weights.max() return samples, indices, np.array(weights, dtype=np.float32) def update_priorities(self, batch_indices, batch_priorities): for idx, prio in zip(batch_indices, batch_priorities): self.priorities[idx] = prio

优先回放可以显著提高学习效率,特别是对于稀疏奖励任务。

3.3 超参数调优指南

以下是经过大量实验验证的最佳超参数范围:

超参数推荐值作用
学习率1e-4 ~ 1e-3控制参数更新幅度
折扣因子γ0.95 ~ 0.99平衡即时和未来奖励
回放缓冲区大小1e4 ~ 1e6存储经验的数量
批量大小32 ~ 128每次更新的样本数
ε初始值1.0探索率起始值
ε最终值0.01 ~ 0.1探索率下限
ε衰减率0.99 ~ 0.999探索率衰减速度
目标网络更新频率100 ~ 1000步稳定训练的关键

4. 训练监控与结果分析

4.1 训练曲线可视化

训练过程中需要监控三个关键指标:

  1. 每回合总奖励
  2. 平均Q值
  3. 损失函数值
import matplotlib.pyplot as plt def plot_training(rewards, losses, q_values): plt.figure(figsize=(12, 5)) plt.subplot(131) plt.plot(rewards) plt.title('Episode Rewards') plt.xlabel('Episode') plt.subplot(132) plt.plot(losses) plt.title('Training Loss') plt.xlabel('Step') plt.subplot(133) plt.plot(q_values) plt.title('Average Q Value') plt.xlabel('Step') plt.tight_layout() plt.show()

4.2 性能评估与基准对比

我们在CartPole-v0上对比了不同算法的表现:

算法平均训练回合数达到200分最终稳定性
原始DQN800-1200回合偶尔会崩溃
双DQN600-900回合更加稳定
优先回放DQN500-800回合最稳定

实际训练中,完整实现通常能在1000回合内稳定达到200分满分。如果训练不顺利,可以检查以下几点:

  1. 奖励不增长:可能是学习率太高或网络结构不合理
  2. 奖励波动大:尝试减小批量大小或增加回放缓冲区
  3. 早期崩溃:调整ε衰减速度,保证充分探索

4.3 实际部署注意事项

当模型训练完成后,可以保存并加载模型进行部署:

# 保存模型 torch.save(policy_net.state_dict(), 'dqn_cartpole.pth') # 加载模型 loaded_net = DQN(state_size, action_size) loaded_net.load_state_dict(torch.load('dqn_cartpole.pth')) loaded_net.eval()

部署时建议:

  1. 关闭探索(ε=0)
  2. 添加异常处理,防止意外状态
  3. 考虑模型量化减小部署体积
http://www.jsqmd.com/news/1131402/

相关文章:

  • COUNT(DISTINCT) 与 GROUP BY 去重统计:5 亿数据量下的性能实测与选型指南
  • 英雄联盟自动化工具箱:League Akari 终极使用指南
  • 数据库设计中的3个常见误区:混淆模式、外模式与物理存储导致的性能与维护问题
  • 中文大模型选型不是比参数,而是做工程化决策
  • 移动端集成Chinese-CLIP:从模型优化到Android/iOS部署实战
  • React Server Components安全漏洞CVE-2025-55182深度剖析与防御实践
  • FSConv频域-空域融合改进YOLOv26小目标检测
  • 如何在iOS 14-16.6.1上快速安装TrollStore:TrollInstallerX完整教程指南
  • OpenCV 4.x 多通道 Mat 极值查找:2种高效方案与 minMaxIdx 详解
  • 抖音评论数据采集神器:三步轻松获取完整评论数据,无需编程基础
  • Visual C++ 运行时库一键安装终极指南:告别DLL缺失烦恼
  • 星露谷物语终极MOD指南:5个步骤打造智能自动化农场
  • STM32与LENA-R8构建全球定位与通信嵌入式系统
  • Xilinx 7系列FPGA DDR3 PCB布线实战:1866Mbps速率下走线长度与端接电阻计算
  • 深度学习对抗样本攻击与防御实战解析
  • Go 配置中心落地:动态配置不是线上手改开关
  • 简单三步禁用Windows Defender防火墙:no-defender完全使用指南
  • Python自动化工具对比:Selenium与Puppeteer/Playwright的架构与实战解析
  • 微信聊天记录备份与查看全攻略:从本地数据库到高效信息管理
  • 5分钟全面掌握Google Authenticator:动态验证码原理与实战部署
  • 终极指南:在Windows上完美驱动Apple触控板的完整解决方案
  • 124、Decoupled Head 替换 YOLOv11 Detect Head:分类与回归分支分离的完整代码
  • 从Wireshark抓包到Modbus协议分析:实战解析工控流量中的隐藏数据
  • Seraphine:基于LCU API的英雄联盟智能游戏助手技术解析与应用指南
  • 含金量高的EMBA|2026国内及境外中英双语EMBA综合实力TOP5榜单
  • Agentic AI安全架构:构建抗提示注入攻击的多层防御体系
  • OpenCV 4.8 双目立体匹配实战:BM/SGBM/GC 3种算法在Middlebury数据集上的精度与速度对比
  • UI-TARS桌面版多用户协作部署:从远程桌面到API调用的完整指南
  • Win11Debloat:完全免费的Windows系统优化终极指南
  • Claude Code与Codex深度对比:AI编程副驾选型指南