当前位置: 首页 > news >正文

别光看理论了!用PyTorch手把手实现一个Actor-Critic模型(附完整代码)

别光看理论了!用PyTorch手把手实现一个Actor-Critic模型(附完整代码)

在强化学习领域,Actor-Critic算法因其结合了策略梯度(Policy Gradient)和价值函数(Value Function)的优点而广受欢迎。然而,许多学习者在掌握了基础理论后,往往在实际编码实现时遇到困难。本文将带你从零开始,使用PyTorch框架完整实现一个Actor-Critic模型,并在经典的CartPole环境中进行验证。

1. 环境准备与基础配置

在开始编码之前,我们需要准备好开发环境。推荐使用Python 3.8+和PyTorch 1.10+版本,这些组合能提供良好的兼容性和性能表现。

首先安装必要的依赖库:

pip install gym torch numpy matplotlib

CartPole是OpenAI Gym中的一个经典控制问题,目标是保持杆子竖直不倒。这个环境非常适合用来测试和验证强化学习算法,因为它的状态空间和动作空间都比较简单,但又能体现算法的有效性。

import gym import torch import torch.nn as nn import torch.optim as optim import numpy as np from collections import deque import random import matplotlib.pyplot as plt # 创建环境 env = gym.make('CartPole-v1') state_dim = env.observation_space.shape[0] action_dim = env.action_space.n

2. 构建Actor和Critic网络

Actor-Critic算法的核心在于同时维护两个网络:Actor(策略网络)和Critic(价值网络)。这两个网络将共享一些基础特征提取层,但最终输出不同的结果。

2.1 Actor网络实现

Actor网络负责根据当前状态输出动作的概率分布。在离散动作空间中,我们通常使用softmax函数将网络输出转换为概率。

class Actor(nn.Module): def __init__(self, state_dim, action_dim, hidden_size=128): super(Actor, self).__init__() self.fc1 = nn.Linear(state_dim, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, action_dim) def forward(self, state): x = torch.relu(self.fc1(state)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return torch.softmax(x, dim=-1)

2.2 Critic网络实现

Critic网络的任务是评估当前状态的价值,为Actor提供改进方向的信号。我们实现一个状态价值函数(V(s))而非动作价值函数(Q(s,a)),这样结构更简单。

class Critic(nn.Module): def __init__(self, state_dim, hidden_size=128): super(Critic, self).__init__() self.fc1 = nn.Linear(state_dim, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, 1) def forward(self, state): x = torch.relu(self.fc1(state)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return x

3. 训练算法实现

Actor-Critic算法的训练过程需要同时更新两个网络。下面是完整的训练循环实现,包括经验回放、网络更新和策略评估等关键部分。

3.1 经验回放缓冲区

虽然标准的Actor-Critic是on-policy算法,但加入经验回放可以显著提高样本效率。我们实现一个简单的回放缓冲区。

class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): batch = random.sample(self.buffer, min(len(self.buffer), batch_size)) states, actions, rewards, next_states, dones = zip(*batch) return states, actions, rewards, next_states, dones def __len__(self): return len(self.buffer)

3.2 训练循环实现

下面是完整的训练过程,包括网络初始化、数据收集和参数更新。

def train_actor_critic(env, episodes=1000, batch_size=64, gamma=0.99, actor_lr=1e-3, critic_lr=1e-3, buffer_capacity=10000): # 初始化网络和优化器 actor = Actor(state_dim, action_dim) critic = Critic(state_dim) actor_optim = optim.Adam(actor.parameters(), lr=actor_lr) critic_optim = optim.Adam(critic.parameters(), lr=critic_lr) buffer = ReplayBuffer(buffer_capacity) episode_rewards = [] for episode in range(episodes): state = env.reset() done = False total_reward = 0 while not done: # 选择动作 state_tensor = torch.FloatTensor(state).unsqueeze(0) action_probs = actor(state_tensor) action = torch.multinomial(action_probs, 1).item() # 执行动作 next_state, reward, done, _ = env.step(action) total_reward += reward # 存储经验 buffer.push(state, action, reward, next_state, done) state = next_state # 当缓冲区有足够样本时开始训练 if len(buffer) >= batch_size: states, actions, rewards, next_states, dones = buffer.sample(batch_size) states = torch.FloatTensor(np.array(states)) actions = torch.LongTensor(np.array(actions)).unsqueeze(1) rewards = torch.FloatTensor(np.array(rewards)).unsqueeze(1) next_states = torch.FloatTensor(np.array(next_states)) dones = torch.FloatTensor(np.array(dones)).unsqueeze(1) # 计算Critic损失 current_values = critic(states) next_values = critic(next_states).detach() target_values = rewards + gamma * next_values * (1 - dones) critic_loss = nn.MSELoss()(current_values, target_values) # 更新Critic critic_optim.zero_grad() critic_loss.backward() critic_optim.step() # 计算Actor损失 action_probs = actor(states) selected_action_probs = action_probs.gather(1, actions) advantages = target_values - current_values.detach() actor_loss = -torch.mean(torch.log(selected_action_probs) * advantages) # 更新Actor actor_optim.zero_grad() actor_loss.backward() actor_optim.step() episode_rewards.append(total_reward) # 打印训练进度 if (episode + 1) % 10 == 0: avg_reward = np.mean(episode_rewards[-10:]) print(f"Episode {episode+1}, Avg Reward: {avg_reward:.1f}") return episode_rewards

4. 超参数调优与训练技巧

Actor-Critic算法对超参数比较敏感,合理的参数设置能显著提高训练效果。下面是一些实用的调优建议:

4.1 学习率设置

  • Actor和Critic的学习率通常需要分别设置
  • Critic的学习率一般比Actor稍大(如1e-3 vs 5e-4)
  • 可以使用学习率衰减策略

4.2 折扣因子γ

  • γ控制未来奖励的重要性
  • 对于CartPole,0.95-0.99是合理范围
  • 更长的episode需要更大的γ

4.3 网络结构选择

层数隐藏单元数适用场景
2-364-256简单环境
3-5256-1024复杂环境

4.4 常见问题解决

  1. 训练不稳定

    • 使用梯度裁剪(torch.nn.utils.clip_grad_norm_
    • 降低学习率
    • 增加批量大小
  2. 探索不足

    • 在动作选择时添加ε-greedy策略
    • 增加策略熵正则项
  3. 价值估计偏差

    • 使用目标网络(Target Network)
    • 实现Advantage函数(A2C)

5. 结果可视化与分析

训练完成后,我们需要评估模型的性能并可视化训练过程。下面是结果分析和可视化的代码示例。

def plot_training_results(rewards, window=10): plt.figure(figsize=(12, 6)) # 原始奖励曲线 plt.subplot(1, 2, 1) plt.plot(rewards) plt.title("Raw Training Rewards") plt.xlabel("Episode") plt.ylabel("Reward") # 滑动平均奖励 plt.subplot(1, 2, 2) moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid') plt.plot(moving_avg) plt.title(f"Moving Average (window={window})") plt.xlabel("Episode") plt.ylabel("Average Reward") plt.tight_layout() plt.show() # 运行训练并绘制结果 rewards = train_actor_critic(env, episodes=500) plot_training_results(rewards)

在实际测试中,一个训练良好的Actor-Critic模型应该能在100-200个episode内学会平衡CartPole,并持续保持杆子直立超过195步(Gym中的解决标准)。

6. 进阶改进方向

基础版本的Actor-Critic已经能解决CartPole问题,但对于更复杂的环境,可以考虑以下改进:

6.1 Advantage Actor-Critic (A2C)

A2C通过计算优势函数(Advantage)来减少方差:

# 在训练循环中替换advantages的计算 advantages = target_values - current_values.detach()

6.2 使用多个并行环境

通过同时运行多个环境实例来加速数据收集:

from multiprocessing import Process, Pipe def worker(env_name, conn): env = gym.make(env_name) while True: cmd, data = conn.recv() if cmd == "step": conn.send(env.step(data)) elif cmd == "reset": conn.send(env.reset()) elif cmd == "close": env.close() conn.close() break

6.3 添加熵正则项

鼓励探索,防止策略过早收敛:

# 在actor_loss计算中添加熵项 entropy = -torch.sum(action_probs * torch.log(action_probs), dim=1).mean() actor_loss = actor_loss - 0.01 * entropy

6.4 实现PPO算法

PPO(Proximal Policy Optimization)是Actor-Critic的改进版本,通过限制策略更新幅度来提高稳定性:

# PPO的核心更新步骤 ratio = (new_probs / old_probs).gather(1, actions) surr1 = ratio * advantages surr2 = torch.clamp(ratio, 1-epsilon, 1+epsilon) * advantages actor_loss = -torch.min(surr1, surr2).mean()
http://www.jsqmd.com/news/680799/

相关文章:

  • 【微软官方未公开的EF Core 10向量陷阱】:为什么AsNoTracking()会导致相似度计算偏移?
  • 拯救者笔记本终极优化指南:Lenovo Legion Toolkit深度探索与实战应用
  • 2026年市面上质量好的中走丝机床品牌推荐榜 - 品牌排行榜
  • 嘉兴庭院花园设计施工公司推荐榜单 - 品牌排行榜
  • 告别低效!用Python+SciPy从零实现多相滤波信道化(附完整代码与避坑指南)
  • Windows PDF处理神器:Poppler零依赖安装指南
  • 异步电路后端实现中的CDC签核:从约束到收敛的实战指南
  • 港科大:揭示AI图文模型存在伪统一性根本缺陷能力突破
  • 2026电压力锅哪个牌子最好最安全?安全与性能深度解析 - 品牌排行榜
  • 复古收音机技术‘复活’记:用2SK241 JFET打造150kHz高灵敏度接收前端
  • Python3 模块精讲:StringIO —— 内存字符串 IO 全解与实战
  • 告别裸机:在S32K3上基于RTOS(如FreeRTOS)构建稳定的FlexCAN多任务通信框架
  • 杭州庭院设计施工公司排行及服务特色解析 - 品牌排行榜
  • 从洪水预测到服务器监控:极值理论EVT在SRE运维中的‘降本增效’实践
  • 杭州屋顶花园设计施工企业推荐及服务解析 - 品牌排行榜
  • 慕尼黑大学团队:AI终于学会像人类一样“推演未来“
  • XUnity.AutoTranslator完整指南:5分钟实现Unity游戏多语言翻译
  • AudioSeal Pixel Studio快速部署:阿里云ECS+NGINX反向代理的公网访问配置
  • 常州国德液压性价比如何,反馈情况好不好 - myqiye
  • XUnity.AutoTranslator深度解析:架构设计与高级应用指南
  • 聊聊2026年鼎成钙业实力怎么样,全国高性价比碳酸钙企业推荐 - 工业品牌热点
  • 康奈尔大学等发现:用更少的题目,反而能训练出更好的AI提示词
  • 二零二六年行业内质量好的线切割机床制造厂家有哪些 - 品牌排行榜
  • 如何用Bili2text将B站视频快速转为文字稿:实用指南
  • fatal error C1007: 无法识别的标志“-typedil”(在“p2”中)
  • 深聊鼎成钙业规模、团队专业性及未来发展趋势,全国客户靠谱之选? - 工业推荐榜
  • 告别数据丢失!用DMA解放你的STM32F103C8T6 CPU,高效处理ADC多通道采样
  • Seraphine终极指南:如何通过智能BP系统快速提升英雄联盟段位
  • 2026年液压机械公司哪家好,分析常州国德液压评价与品牌价值 - mypinpai
  • AI 技术日报 - 2026-04-22