SAC算法里的“熵”到底是啥?用Python代码带你直观理解最大熵强化学习
SAC算法中的"熵":用Python代码揭开强化学习探索之谜
在强化学习的世界里,我们常常教导智能体要"聪明"地行动——选择那些能带来最高奖励的动作。但有趣的是,最先进的算法如SAC(Soft Actor-Critic)却反其道而行之,它鼓励智能体表现得"不那么确定",这就是"熵"的魔力。本文将通过Python代码,带你直观理解这个看似矛盾却极其强大的概念。
1. 熵在强化学习中的直观意义
想象你正在玩一个全新的电子游戏。如果只选择已知能得分的操作,你可能永远发现不了隐藏的彩蛋或更高效的得分方式。这就是传统强化学习的局限——过于"功利"的智能体容易陷入局部最优。而SAC通过引入熵,让智能体保持适度的"好奇心"。
熵的数学定义很简单:对于一个概率分布π(a|s),其熵H(π) = -Σπ(a|s)logπ(a|s)。在代码中,我们可以这样计算:
import numpy as np def compute_entropy(prob_dist): return -np.sum(prob_dist * np.log(prob_dist + 1e-10)) # 加小量避免log(0) # 示例:两个不同的策略在3个动作上的分布 deterministic_policy = np.array([0.9, 0.1, 0.0]) # 确定性强的策略 random_policy = np.array([0.4, 0.3, 0.3]) # 随机性强的策略 print(f"确定性策略熵: {compute_entropy(deterministic_policy):.3f}") print(f"随机策略熵: {compute_entropy(random_policy):.3f}")运行这段代码,你会看到确定性策略的熵值明显更低。SAC的核心思想就是在奖励函数中加入这个熵值作为额外奖励,鼓励策略保持一定的随机性。
2. 构建极简SAC:从网格世界开始
为了直观展示熵的作用,我们实现一个简化版SAC来解决网格世界问题。这个环境包含:
- 5x5网格
- 起点在(0,0),目标在(4,4)
- 某些格子有惩罚(悬崖)
- 动作空间:上、下、左、右
import torch import torch.nn as nn import torch.optim as optim import numpy as np from collections import deque import random class GridWorld: def __init__(self): self.size = 5 self.goal = (4, 4) self.cliffs = [(1, 2), (2, 2), (3, 2)] self.reset() def reset(self): self.pos = (0, 0) return self.pos def step(self, action): x, y = self.pos if action == 0: y = min(y + 1, self.size - 1) # 上 elif action == 1: y = max(y - 1, 0) # 下 elif action == 2: x = max(x - 1, 0) # 左 elif action == 3: x = min(x + 1, self.size - 1) # 右 self.pos = (x, y) if self.pos in self.cliffs: return self.pos, -10, True if self.pos == self.goal: return self.pos, 10, True return self.pos, -0.1, False # 每步小惩罚鼓励尽快到达目标3. SAC核心组件实现
我们的极简SAC包含三个关键部分:策略网络(Actor)、两个Q网络(Critic)和自动调节的温度参数α。
class PolicyNetwork(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim=64): super().__init__() self.fc1 = nn.Linear(state_dim, hidden_dim) self.fc_mean = nn.Linear(hidden_dim, action_dim) self.fc_logstd = nn.Linear(hidden_dim, action_dim) def forward(self, state): x = torch.relu(self.fc1(state)) mean = torch.tanh(self.fc_mean(x)) # 输出在[-1,1]之间 log_std = self.fc_logstd(x) return mean, log_std def sample_action(self, state): mean, log_std = self.forward(state) std = log_std.exp() normal = torch.distributions.Normal(mean, std) action = normal.rsample() # 重参数化采样 log_prob = normal.log_prob(action).sum(-1) return action.tanh(), log_prob # 使用tanh确保动作在[-1,1] class QNetwork(nn.Module): def __init__(self, state_dim, action_dim, hidden_dim=64): super().__init__() self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, hidden_dim) self.fc_out = nn.Linear(hidden_dim, 1) def forward(self, state, action): x = torch.cat([state, action], dim=-1) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) return self.fc_out(x)4. 熵如何影响智能体行为
温度参数α控制着熵对策略的影响程度。我们可以通过调整α值来观察智能体的行为变化:
def train_sac(env, alpha=0.2, episodes=1000): state_dim = 2 # (x,y)坐标 action_dim = 4 # 上下左右 # 初始化网络 policy = PolicyNetwork(state_dim, action_dim) q1 = QNetwork(state_dim, action_dim) q2 = QNetwork(state_dim, action_dim) # 优化器 policy_optim = optim.Adam(policy.parameters(), lr=3e-4) q_optim = optim.Adam(list(q1.parameters()) + list(q2.parameters()), lr=3e-4) replay_buffer = deque(maxlen=10000) batch_size = 64 for ep in range(episodes): state = env.reset() episode_reward = 0 while True: state_tensor = torch.FloatTensor(state) action, log_prob = policy.sample_action(state_tensor) action_idx = torch.argmax(action).item() # 简化处理 next_state, reward, done = env.step(action_idx) replay_buffer.append((state, action_idx, reward, next_state, done)) # 训练步骤 if len(replay_buffer) > batch_size: batch = random.sample(replay_buffer, batch_size) states, actions, rewards, next_states, dones = zip(*batch) states = torch.FloatTensor(np.array(states)) actions = torch.LongTensor(np.array(actions)) rewards = torch.FloatTensor(np.array(rewards)) next_states = torch.FloatTensor(np.array(next_states)) dones = torch.FloatTensor(np.array(dones)) # Q函数更新 with torch.no_grad(): next_actions, next_log_probs = policy.sample_action(next_states) q1_next = q1(next_states, next_actions) q2_next = q2(next_states, next_actions) q_next = torch.min(q1_next, q2_next) - alpha * next_log_probs target_q = rewards + 0.99 * (1 - dones) * q_next.squeeze() current_q1 = q1(states, actions) current_q2 = q2(states, actions) q1_loss = nn.MSELoss()(current_q1.squeeze(), target_q) q2_loss = nn.MSELoss()(current_q2.squeeze(), target_q) q_loss = q1_loss + q2_loss q_optim.zero_grad() q_loss.backward() q_optim.step() # 策略更新 new_actions, new_log_probs = policy.sample_action(states) q1_new = q1(states, new_actions) q2_new = q2(states, new_actions) q_new = torch.min(q1_new, q2_new) policy_loss = (alpha * new_log_probs - q_new).mean() policy_optim.zero_grad() policy_loss.backward() policy_optim.step() episode_reward += reward state = next_state if done: break if ep % 50 == 0: print(f"Episode {ep}, Reward: {episode_reward:.1f}")5. 温度参数α的调节艺术
α值的选择直接影响智能体的探索行为:
- 高α值(如1.0):智能体像"好奇宝宝",愿意尝试各种路径,即使看起来不是最优的
- 低α值(如0.1):智能体变得"功利",快速锁定看似最优的路径
- 自动调节的α:SAC通常会自动调整α,保持策略熵在一个目标值附近
我们可以通过实验观察不同α值的效果:
# 比较不同α值的效果 for alpha in [0.1, 0.5, 1.0]: print(f"\nTraining with alpha={alpha}") env = GridWorld() train_sac(env, alpha=alpha, episodes=300)在实际运行中,你会发现:
- α=0.1时,智能体倾向于选择最短路径,但可能掉入悬崖
- α=1.0时,智能体会探索更多路径,最终可能发现更安全的路线
- 适中的α值(如0.5)能在探索和利用间取得平衡
6. 可视化熵在训练中的变化
为了更直观理解熵的作用,我们可以记录训练过程中策略熵的变化:
import matplotlib.pyplot as plt def plot_entropy_during_training(): alphas = [0.1, 0.5, 1.0] entropy_records = {alpha: [] for alpha in alphas} for alpha in alphas: env = GridWorld() policy = PolicyNetwork(2, 4) for _ in range(100): state = env.reset() state_tensor = torch.FloatTensor(state) _, log_prob = policy.sample_action(state_tensor) entropy = -log_prob.exp() * log_prob # 近似计算熵 entropy_records[alpha].append(entropy.item()) plt.figure(figsize=(10, 6)) for alpha, entropies in entropy_records.items(): plt.plot(entropies, label=f"α={alpha}") plt.xlabel("Training Steps") plt.ylabel("Policy Entropy") plt.title("Policy Entropy Under Different α Values") plt.legend() plt.show() plot_entropy_during_training()这张图会清晰展示:
- 高α值对应更高的策略熵(更多探索)
- 随着训练进行,所有策略的熵都会逐渐降低(学会利用)
- 但高α值的策略始终保持着更高的随机性
7. SAC在实际问题中的优势
通过这个简化实现,我们可以看到SAC相比传统强化学习算法的优势:
- 更鲁棒的策略学习:不会轻易陷入局部最优
- 自动平衡探索与利用:通过熵正则化自然实现
- 适应复杂环境:在多模态奖励场景下表现优异
例如,在机器人控制中,SAC能让机器人:
- 尝试多种行走方式,而不仅限于一种固定步态
- 在遇到障碍时,能灵活切换策略
- 持续学习新技能而不忘记已有能力
# 实际应用中的SAC通常包含更多优化 class AdvancedSAC: def __init__(self, state_dim, action_dim): # 双Q网络和目标网络 self.q1 = QNetwork(state_dim, action_dim) self.q2 = QNetwork(state_dim, action_dim) self.target_q1 = QNetwork(state_dim, action_dim) self.target_q2 = QNetwork(state_dim, action_dim) # 自动调节的温度参数α self.target_entropy = -action_dim # 常见设置 self.log_alpha = torch.zeros(1, requires_grad=True) self.alpha_optim = optim.Adam([self.log_alpha], lr=3e-4) def update_alpha(self, log_probs): alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean() self.alpha_optim.zero_grad() alpha_loss.backward() self.alpha_optim.step() return self.log_alpha.exp().item()这个进阶实现展示了SAC在实际应用中的常见组件,包括目标网络和自动温度调节,这些都是确保算法稳定性的关键。
