别光看理论了!用PyTorch手把手实现一个Actor-Critic模型(附完整代码)
别光看理论了!用PyTorch手把手实现一个Actor-Critic模型(附完整代码)
在强化学习领域,Actor-Critic算法因其结合了策略梯度(Policy Gradient)和价值函数(Value Function)的优点而广受欢迎。然而,许多学习者在掌握了基础理论后,往往在实际编码实现时遇到困难。本文将带你从零开始,使用PyTorch框架完整实现一个Actor-Critic模型,并在经典的CartPole环境中进行验证。
1. 环境准备与基础配置
在开始编码之前,我们需要准备好开发环境。推荐使用Python 3.8+和PyTorch 1.10+版本,这些组合能提供良好的兼容性和性能表现。
首先安装必要的依赖库:
pip install gym torch numpy matplotlibCartPole是OpenAI Gym中的一个经典控制问题,目标是保持杆子竖直不倒。这个环境非常适合用来测试和验证强化学习算法,因为它的状态空间和动作空间都比较简单,但又能体现算法的有效性。
import gym import torch import torch.nn as nn import torch.optim as optim import numpy as np from collections import deque import random import matplotlib.pyplot as plt # 创建环境 env = gym.make('CartPole-v1') state_dim = env.observation_space.shape[0] action_dim = env.action_space.n2. 构建Actor和Critic网络
Actor-Critic算法的核心在于同时维护两个网络:Actor(策略网络)和Critic(价值网络)。这两个网络将共享一些基础特征提取层,但最终输出不同的结果。
2.1 Actor网络实现
Actor网络负责根据当前状态输出动作的概率分布。在离散动作空间中,我们通常使用softmax函数将网络输出转换为概率。
class Actor(nn.Module): def __init__(self, state_dim, action_dim, hidden_size=128): super(Actor, self).__init__() self.fc1 = nn.Linear(state_dim, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, action_dim) def forward(self, state): x = torch.relu(self.fc1(state)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return torch.softmax(x, dim=-1)2.2 Critic网络实现
Critic网络的任务是评估当前状态的价值,为Actor提供改进方向的信号。我们实现一个状态价值函数(V(s))而非动作价值函数(Q(s,a)),这样结构更简单。
class Critic(nn.Module): def __init__(self, state_dim, hidden_size=128): super(Critic, self).__init__() self.fc1 = nn.Linear(state_dim, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, 1) def forward(self, state): x = torch.relu(self.fc1(state)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return x3. 训练算法实现
Actor-Critic算法的训练过程需要同时更新两个网络。下面是完整的训练循环实现,包括经验回放、网络更新和策略评估等关键部分。
3.1 经验回放缓冲区
虽然标准的Actor-Critic是on-policy算法,但加入经验回放可以显著提高样本效率。我们实现一个简单的回放缓冲区。
class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): batch = random.sample(self.buffer, min(len(self.buffer), batch_size)) states, actions, rewards, next_states, dones = zip(*batch) return states, actions, rewards, next_states, dones def __len__(self): return len(self.buffer)3.2 训练循环实现
下面是完整的训练过程,包括网络初始化、数据收集和参数更新。
def train_actor_critic(env, episodes=1000, batch_size=64, gamma=0.99, actor_lr=1e-3, critic_lr=1e-3, buffer_capacity=10000): # 初始化网络和优化器 actor = Actor(state_dim, action_dim) critic = Critic(state_dim) actor_optim = optim.Adam(actor.parameters(), lr=actor_lr) critic_optim = optim.Adam(critic.parameters(), lr=critic_lr) buffer = ReplayBuffer(buffer_capacity) episode_rewards = [] for episode in range(episodes): state = env.reset() done = False total_reward = 0 while not done: # 选择动作 state_tensor = torch.FloatTensor(state).unsqueeze(0) action_probs = actor(state_tensor) action = torch.multinomial(action_probs, 1).item() # 执行动作 next_state, reward, done, _ = env.step(action) total_reward += reward # 存储经验 buffer.push(state, action, reward, next_state, done) state = next_state # 当缓冲区有足够样本时开始训练 if len(buffer) >= batch_size: states, actions, rewards, next_states, dones = buffer.sample(batch_size) states = torch.FloatTensor(np.array(states)) actions = torch.LongTensor(np.array(actions)).unsqueeze(1) rewards = torch.FloatTensor(np.array(rewards)).unsqueeze(1) next_states = torch.FloatTensor(np.array(next_states)) dones = torch.FloatTensor(np.array(dones)).unsqueeze(1) # 计算Critic损失 current_values = critic(states) next_values = critic(next_states).detach() target_values = rewards + gamma * next_values * (1 - dones) critic_loss = nn.MSELoss()(current_values, target_values) # 更新Critic critic_optim.zero_grad() critic_loss.backward() critic_optim.step() # 计算Actor损失 action_probs = actor(states) selected_action_probs = action_probs.gather(1, actions) advantages = target_values - current_values.detach() actor_loss = -torch.mean(torch.log(selected_action_probs) * advantages) # 更新Actor actor_optim.zero_grad() actor_loss.backward() actor_optim.step() episode_rewards.append(total_reward) # 打印训练进度 if (episode + 1) % 10 == 0: avg_reward = np.mean(episode_rewards[-10:]) print(f"Episode {episode+1}, Avg Reward: {avg_reward:.1f}") return episode_rewards4. 超参数调优与训练技巧
Actor-Critic算法对超参数比较敏感,合理的参数设置能显著提高训练效果。下面是一些实用的调优建议:
4.1 学习率设置
- Actor和Critic的学习率通常需要分别设置
- Critic的学习率一般比Actor稍大(如1e-3 vs 5e-4)
- 可以使用学习率衰减策略
4.2 折扣因子γ
- γ控制未来奖励的重要性
- 对于CartPole,0.95-0.99是合理范围
- 更长的episode需要更大的γ
4.3 网络结构选择
| 层数 | 隐藏单元数 | 适用场景 |
|---|---|---|
| 2-3 | 64-256 | 简单环境 |
| 3-5 | 256-1024 | 复杂环境 |
4.4 常见问题解决
训练不稳定:
- 使用梯度裁剪(
torch.nn.utils.clip_grad_norm_) - 降低学习率
- 增加批量大小
- 使用梯度裁剪(
探索不足:
- 在动作选择时添加ε-greedy策略
- 增加策略熵正则项
价值估计偏差:
- 使用目标网络(Target Network)
- 实现Advantage函数(A2C)
5. 结果可视化与分析
训练完成后,我们需要评估模型的性能并可视化训练过程。下面是结果分析和可视化的代码示例。
def plot_training_results(rewards, window=10): plt.figure(figsize=(12, 6)) # 原始奖励曲线 plt.subplot(1, 2, 1) plt.plot(rewards) plt.title("Raw Training Rewards") plt.xlabel("Episode") plt.ylabel("Reward") # 滑动平均奖励 plt.subplot(1, 2, 2) moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid') plt.plot(moving_avg) plt.title(f"Moving Average (window={window})") plt.xlabel("Episode") plt.ylabel("Average Reward") plt.tight_layout() plt.show() # 运行训练并绘制结果 rewards = train_actor_critic(env, episodes=500) plot_training_results(rewards)在实际测试中,一个训练良好的Actor-Critic模型应该能在100-200个episode内学会平衡CartPole,并持续保持杆子直立超过195步(Gym中的解决标准)。
6. 进阶改进方向
基础版本的Actor-Critic已经能解决CartPole问题,但对于更复杂的环境,可以考虑以下改进:
6.1 Advantage Actor-Critic (A2C)
A2C通过计算优势函数(Advantage)来减少方差:
# 在训练循环中替换advantages的计算 advantages = target_values - current_values.detach()6.2 使用多个并行环境
通过同时运行多个环境实例来加速数据收集:
from multiprocessing import Process, Pipe def worker(env_name, conn): env = gym.make(env_name) while True: cmd, data = conn.recv() if cmd == "step": conn.send(env.step(data)) elif cmd == "reset": conn.send(env.reset()) elif cmd == "close": env.close() conn.close() break6.3 添加熵正则项
鼓励探索,防止策略过早收敛:
# 在actor_loss计算中添加熵项 entropy = -torch.sum(action_probs * torch.log(action_probs), dim=1).mean() actor_loss = actor_loss - 0.01 * entropy6.4 实现PPO算法
PPO(Proximal Policy Optimization)是Actor-Critic的改进版本,通过限制策略更新幅度来提高稳定性:
# PPO的核心更新步骤 ratio = (new_probs / old_probs).gather(1, actions) surr1 = ratio * advantages surr2 = torch.clamp(ratio, 1-epsilon, 1+epsilon) * advantages actor_loss = -torch.min(surr1, surr2).mean()