别再死记硬背Sarsa公式了!用Python手搓一个走迷宫AI,5分钟搞懂On-Policy和Q-learning的区别
用Python构建迷宫AI:5分钟可视化Sarsa与Q-learning的本质差异
在咖啡厅里,我常看到学生对着强化学习教材皱眉——那些充满希腊字母的公式和抽象概念,确实容易让人望而生畏。直到有一天,我让学生用代码实现了一个会走迷宫的AI,他们突然恍然大悟:"原来On-Policy和Off-Policy的区别这么直观!" 本文将带你复现这个魔法时刻:不需要死记硬背贝尔曼方程,而是通过编写一个会学习的迷宫探索者,亲眼见证Sarsa和Q-learning在策略选择上的根本差异。
1. 准备迷宫实验室
我们先搭建一个简单的网格世界。想象一个5×5的迷宫,其中(0,0)是起点,(4,4)是终点,某些格子是陷阱(奖励-1),终点有丰厚奖励(+10)。使用numpy和matplotlib就能构建这个微型世界:
import numpy as np import matplotlib.pyplot as plt class MazeEnv: def __init__(self): self.size = 5 self.start = (0, 0) self.goal = (4, 4) self.obstacles = [(1, 1), (2, 3), (3, 1)] self.actions = ['up', 'down', 'left', 'right'] def step(self, state, action): x, y = state if action == 'up': x = max(0, x-1) elif action == 'down': x = min(self.size-1, x+1) elif action == 'left': y = max(0, y-1) elif action == 'right': y = min(self.size-1, y+1) new_state = (x, y) if new_state in self.obstacles: return state, -1, True # 撞墙回弹 if new_state == self.goal: return new_state, 10, True return new_state, -0.1, False # 每步小惩罚鼓励快速通关关键设计细节:
- 每步给予-0.1的奖励,促使AI寻找最短路径
- 障碍物碰撞会获得-1奖励并保持原地
- 使用离散动作空间(上/下/左/右)简化问题
2. Sarsa算法实现
Sarsa作为On-Policy算法,其核心特点是"言行一致"——它用当前策略既选择动作也更新Q值。我们用一个SarsaAgent类来实现:
class SarsaAgent: def __init__(self, env, learning_rate=0.1, discount=0.9, epsilon=0.1): self.q_table = np.zeros((env.size, env.size, len(env.actions))) self.lr = learning_rate self.gamma = discount self.epsilon = epsilon self.actions = env.actions def choose_action(self, state, train=True): if train and np.random.random() < self.epsilon: return np.random.choice(len(self.actions)) return np.argmax(self.q_table[state]) def learn(self, state, action, reward, next_state, next_action): current_q = self.q_table[state][action] next_q = self.q_table[next_state][next_action] td_target = reward + self.gamma * next_q self.q_table[state][action] += self.lr * (td_target - current_q)算法运行流程:
- 在状态Sₜ根据ε-greedy策略选择动作Aₜ
- 执行动作后获得Rₜ₊₁和Sₜ₊₁
- 在Sₜ₊₁继续用相同策略选择Aₜ₊₁
- 用五元组(Sₜ, Aₜ, Rₜ₊₁, Sₜ₊₁, Aₜ₊₁)更新Q表
观察下面这个训练过程的可视化,你会发现Sarsa的路径往往更加保守:
Episode 1: S→→→↓→→→→G (碰撞2次) Episode 50: S→→↓→→→G Episode 100: S→↓→→G (稳定路径)3. Q-learning实现对比
Q-learning作为Off-Policy算法,其更新规则允许"说一套做一套"。我们只需修改learn方法:
class QLearningAgent(SarsaAgent): def learn(self, state, action, reward, next_state, _): current_q = self.q_table[state][action] max_next_q = np.max(self.q_table[next_state]) # 关键区别! td_target = reward + self.gamma * max_next_q self.q_table[state][action] += self.lr * (td_target - current_q)核心差异对比表:
| 特性 | Sarsa | Q-learning |
|---|---|---|
| 策略一致性 | On-Policy (言行一致) | Off-Policy (目标策略≠行为策略) |
| 更新公式 | 使用实际执行的Aₜ₊₁ | 使用max Q值的动作 |
| 探索风险 | 会规避危险格子 | 可能靠近危险 |
| 收敛性 | 更稳定 | 可能更激进 |
| 适用场景 | 高风险环境(如机器人控制) | 游戏AI等可承受风险场景 |
4. 可视化对比训练过程
让我们用matplotlib创建动态对比图。以下代码展示两种算法在相同迷宫中的学习轨迹差异:
def plot_comparison(sarsa_paths, qlearn_paths): plt.figure(figsize=(12, 5)) # Sarsa路径绘制 plt.subplot(121) for path in sarsa_paths: plt.plot([p[1] for p in path], [p[0] for p in path], 'b-', alpha=0.1) plt.title("Sarsa (On-Policy) 路径") # Q-learning路径绘制 plt.subplot(122) for path in qlearn_paths: plt.plot([p[1] for p in path], [p[0] for p in path], 'r-', alpha=0.1) plt.title("Q-learning (Off-Policy) 路径")典型现象观察:
- Sarsa:早期会绕开障碍物,即使这意味着更长的路径
- Q-learning:常出现"切角"行为,偶尔会碰到障碍物但最终学会最优路径
5. 高级话题:从表格方法到神经网络
当迷宫扩大到20×20时,Q表格将变得低效。这时可以引入神经网络作为函数逼近器:
import torch import torch.nn as nn class DQN(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, 64), nn.ReLU(), nn.Linear(64, output_dim) ) def forward(self, x): return self.net(x) # Sarsa与Q-learning的神经网络实现差异: # Sarsa需要采样下一个动作Aₜ₊₁,而Q-learning直接取max Q值经验回放的影响:
- Q-learning可以自由使用历史经验
- Sarsa若使用经验回放,需要确保采样的Aₜ₊₁与当前策略兼容
# 伪代码:Sarsa的经验回放特殊处理 for transition in replay_buffer.sample(): s, a, r, s_next, a_next = transition if policy_changed: # 需要重新采样a_next a_next = current_policy.select_action(s_next) agent.learn(s, a, r, s_next, a_next)在项目实践中,我发现当环境随机性较低时(如我们的迷宫),即使Sarsa使用经验回放也能良好工作。但在股票交易等高随机性场景,这种近似可能导致问题。
