Cliff Walking in Practice: Implementing Sarsa and Q-Learning Step by Step in Python (Full Code Included)
Cliff Walking in Practice: An In-Depth Look at Sarsa and Q-Learning in Python
Introduction: When Reinforcement Learning Meets Cliff Walking
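Imagine standing at the start cell in the bottom-left corner of a 4×12 grid world. The goal sits in the bottom-right corner, but a deadly cliff runs along the bottom edge between them. Every step costs energy (reward -1), and falling off the cliff is severely punished (reward -100) and sends you back to the start. This is the classic Cliff Walking environment, a "Hello World" of reinforcement learning that neatly illustrates the never-ending tension between exploration and exploitation.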
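Unlike an ordinary maze, what makes Cliff Walking interesting is that two routes compete:
- Safe path: a longer route that keeps one or more rows away from the cliff (e.g. along the top row: 17 steps, total reward -17)
- Optimal path: the shortest route, hugging the cliff edge (13 steps, total reward -13)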
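The two returns can be checked by replaying each route step by step. The sketch below assumes the Sutton & Barto reward convention (-1 for every step, including the step into the goal; -100 and a reset to the start for the cliff); `path_return` and the move letters are illustrative names, not part of any library.

```python
# A minimal sketch: replay a fixed action sequence on the 4x12 Cliff Walking grid
# and sum the rewards (assumed convention: -1 per step, including the step into
# the goal; -100 and back to the start when entering the cliff).
SHAPE = (4, 12)
START, GOAL = (3, 0), (3, 11)
CLIFF = {(3, c) for c in range(1, 11)}
MOVES = {"U": (-1, 0), "R": (0, 1), "D": (1, 0), "L": (0, -1)}

def path_return(actions):
    """Return (total_reward, reached_goal) for a string of moves such as 'URRD'."""
    r, c = START
    total = 0
    for a in actions:
        dr, dc = MOVES[a]
        # Moves off the grid leave the agent against the border
        r = min(max(r + dr, 0), SHAPE[0] - 1)
        c = min(max(c + dc, 0), SHAPE[1] - 1)
        if (r, c) in CLIFF:
            total += -100
            r, c = START
        else:
            total += -1
        if (r, c) == GOAL:
            return total, True
    return total, False

edge_path = "U" + "R" * 11 + "D"      # hugs the cliff: 13 steps
top_path = "UUU" + "R" * 11 + "DDD"   # along the top row: 17 steps
print(path_return(edge_path))  # (-13, True)
print(path_return(top_path))   # (-17, True)
```

Replaying `"R"` from the start steps straight into the cliff and returns `(-100, False)`, which is why a single careless step outweighs the four extra steps of the safe route.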
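This article walks you through implementing two classic algorithms from scratch in Python: the conservative Sarsa and the risk-taking Q-Learning. With complete code examples and visual analysis, you will gain a solid understanding of:
- The core implementation logic of tabular reinforcement learning
- The fundamental difference in how the two algorithms choose their policies
- How to design an efficient training loop
- How the key hyperparameters affect performance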
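```python
import numpy as np
import matplotlib.pyplot as plt
import gym
from gym import spaces

# The environment below follows the classic gym interface:
# reset() -> state, step() -> (state, reward, done, info)
```

1. Environment Construction: Building Your Own Cliff Walking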
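1.1 A Custom Gym Environment
We start by subclassing `gym.Env`. The key elements are the grid shape, the start and goal positions, the cliff cells, and the action and observation spaces: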
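```python
class CliffWalkingEnv(gym.Env):
    def __init__(self):
        self.shape = (4, 12)
        self.start_pos = (3, 0)
        self.goal_pos = (3, 11)
        # The cliff occupies the bottom row between start and goal
        self.cliff = [(3, i) for i in range(1, 11)]
        self.action_space = spaces.Discrete(4)  # up: 0, right: 1, down: 2, left: 3
        self.observation_space = spaces.Discrete(self.shape[0] * self.shape[1])
        self.reset()

    def _get_state(self):
        # Flatten (row, col) into a single integer state index
        return self.pos[0] * self.shape[1] + self.pos[1]

    def reset(self):
        self.pos = self.start_pos
        return self._get_state()
```

1.2 State Transition Logic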
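Implement the core `step` method (the training loops call `env.step(action)`), which handles movement and reward computation:

```python
    def step(self, action):
        x, y = self.pos
        # Move, clamping at the grid borders
        if action == 0:
            x = max(x - 1, 0)
        elif action == 1:
            y = min(y + 1, self.shape[1] - 1)
        elif action == 2:
            x = min(x + 1, self.shape[0] - 1)
        elif action == 3:
            y = max(y - 1, 0)
        self.pos = (x, y)
        done = False
        reward = -1  # every step costs -1, including the step into the goal
        # Check for the cliff and for termination
        if self.pos in self.cliff:
            reward = -100
            self.reset()  # back to the start; the episode continues
        elif self.pos == self.goal_pos:
            done = True
        return self._get_state(), reward, done, {}
```

1.3 Rendering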
Add a rendering method to show how the agent moves:

```python
    def render(self):
        grid = [['.' for _ in range(self.shape[1])] for _ in range(self.shape[0])]
        grid[self.goal_pos[0]][self.goal_pos[1]] = 'G'  # goal
        for c in self.cliff:
            grid[c[0]][c[1]] = 'X'  # cliff
        grid[self.pos[0]][self.pos[1]] = 'A'  # agent
        for row in grid:
            print(' '.join(row))
        print()
```

2. Implementing Sarsa: the Safety-First Conservative
2.1 Core Idea
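Sarsa is an on-policy algorithm. Its update rule is:

$$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma Q(s',a') - Q(s,a)\right]$$

where $a'$ is the action the current (ε-greedy) policy actually chooses in state $s'$. The policy being evaluated is the same one that generates the behavior, which is the "act, then evaluate" consistency behind the name: each update uses the tuple (s, a, r, s', a').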
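A single update with concrete numbers makes the on-policy target visible. In this hypothetical sketch (the tiny 2-state Q table and `sarsa_update` are illustrative, not from the article), the ε-greedy policy happens to pick an exploratory, low-value action in s', and Sarsa's target inherits that low value:

```python
import numpy as np

def sarsa_update(Q, s, a, r, s_next, a_next, alpha=0.1, gamma=0.9, done=False):
    """One Sarsa step: the target uses Q of the action a_next actually chosen in s_next."""
    target = r if done else r + gamma * Q[s_next, a_next]
    Q[s, a] += alpha * (target - Q[s, a])
    return Q[s, a]

# Toy table: 2 states x 4 actions (hypothetical values for illustration)
Q = np.zeros((2, 4))
Q[1] = [-2.0, -1.0, -10.0, -3.0]
# Suppose epsilon-greedy happened to pick the exploratory action 2 in state 1
new_q = sarsa_update(Q, s=0, a=1, r=-1.0, s_next=1, a_next=2)
print(new_q)  # -1.0
```

The target is -1 + 0.9 × (-10) = -10, and with α = 0.1 the estimate moves one tenth of the way from 0 toward it.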
2.2 Python Implementation
We wrap the core logic in a `SarsaAgent` class:
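```python
class SarsaAgent:
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha      # learning rate
        self.gamma = gamma      # discount factor
        self.epsilon = epsilon  # exploration rate
        self.Q = np.zeros((env.observation_space.n, env.action_space.n))

    def choose_action(self, state):
        # epsilon-greedy: explore with probability epsilon, otherwise act greedily
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        # np.argmax breaks ties in favor of the lowest action index
        return np.argmax(self.Q[state])
```

2.3 The Training Loop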
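The complete training loop shows Sarsa's online learning: every update uses the next action that will actually be executed.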
```python
def train(env, agent, episodes=500):
    rewards = []
    for _ in range(episodes):
        state = env.reset()
        action = agent.choose_action(state)
        total_reward = 0
        done = False
        while not done:
            next_state, reward, done, _ = env.step(action)
            next_action = agent.choose_action(next_state)
            # Sarsa update: bootstrap from the action that will actually be taken;
            # a terminal transition has no successor value
            td_target = reward + agent.gamma * agent.Q[next_state][next_action] * (not done)
            td_error = td_target - agent.Q[state][action]
            agent.Q[state][action] += agent.alpha * td_error
            state, action = next_state, next_action
            total_reward += reward
        rewards.append(total_reward)
    return rewards
```

2.4 Visualizing the Results
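After training, we observe:
- Converged path: the agent learns the safe route higher up the grid
- Learning curve: levels off after roughly 200 episodes
- Policy character: stays away from the cliff edge even though the route is longer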
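To check which route the agent actually converged to, the greedy action in each cell can be printed as an arrow. A minimal sketch, assuming states are numbered row-major as `row * 12 + col` (the layout that `plot_values`'s `reshape(4, 12)` relies on) and the action order up/right/down/left; `greedy_policy_map` is a hypothetical helper:

```python
import numpy as np

ARROWS = "^>v<"  # action order used in this article: up 0, right 1, down 2, left 3

def greedy_policy_map(Q, shape=(4, 12), cliff=None, goal=(3, 11)):
    """Render argmax_a Q(s, a) for every cell; state index = row * n_cols + col."""
    cliff = set(cliff or [])
    rows = []
    for r in range(shape[0]):
        row = []
        for c in range(shape[1]):
            if (r, c) in cliff:
                row.append("X")
            elif (r, c) == goal:
                row.append("G")
            else:
                row.append(ARROWS[int(np.argmax(Q[r * shape[1] + c]))])
        rows.append(" ".join(row))
    return "\n".join(rows)
```

For a trained agent, `print(greedy_policy_map(agent.Q, cliff=env.cliff))` shows the route: right-pointing arrows in the upper rows indicate the safe path, while right-pointing arrows all along row 2 (just above the cliff) indicate the edge path.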
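```python
def moving_average(x, window=10):
    return np.convolve(x, np.ones(window) / window, mode='valid')  # smooth noisy rewards

plt.plot(moving_average(rewards, window=10))  # rewards = train(env, agent)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Sarsa Learning Curve')
plt.show()
```

3. Implementing Q-Learning: the Risk-Taking Optimizer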
3.1 The Key Difference
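Q-Learning is an off-policy algorithm. Its update rule is:

$$Q(s,a) \leftarrow Q(s,a) + \alpha\left[r + \gamma \max_{a'} Q(s',a') - Q(s,a)\right]$$

The key difference is that the target uses the largest Q-value in $s'$ rather than the Q-value of the action actually taken next, so Q-Learning learns the value of the greedy policy even while it behaves ε-greedily.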
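Applying both targets to the same transition shows the difference. This sketch uses a hypothetical Q table in which action 2 in s' is the risky, low-value move; `sarsa_target` and `q_learning_target` are illustrative helper names:

```python
import numpy as np

def sarsa_target(Q, r, s_next, a_next, gamma=0.9):
    # On-policy: bootstrap from the action the behavior policy actually takes next
    return r + gamma * Q[s_next, a_next]

def q_learning_target(Q, r, s_next, gamma=0.9):
    # Off-policy: bootstrap from the greedy (max) action, whatever is taken next
    return r + gamma * np.max(Q[s_next])

Q = np.zeros((2, 4))
Q[1] = [-2.0, -1.0, -10.0, -3.0]  # hypothetical values; action 2 is the risky one
a_next = 2                         # exploratory action picked by epsilon-greedy
print(sarsa_target(Q, -1.0, 1, a_next))           # -10.0
print(round(q_learning_target(Q, -1.0, 1), 6))    # -1.9
```

Sarsa's target drops to -10 because the exploratory action is what will actually run; Q-Learning ignores it and bootstraps from the best action (value -1), which is why it keeps valuing states next to the cliff highly.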
3.2 Python Implementation
Only the update rule in the agent class changes:
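```python
class QLearningAgent(SarsaAgent):
    def update(self, state, action, reward, next_state, done):
        if done:
            td_target = reward  # no successor value after a terminal transition
        else:
            # Bootstrap from the best action in next_state, not the one taken next
            td_target = reward + self.gamma * np.max(self.Q[next_state])
        td_error = td_target - self.Q[state][action]
        self.Q[state][action] += self.alpha * td_error
```

3.3 Adjusting the Training Loop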
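The training loop is adjusted to reflect the off-policy nature: the action is chosen at the start of each step, and the update does not need the next action: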
```python
def qlearn_train(env, agent, episodes=500):
    rewards = []
    for _ in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.choose_action(state)  # behavior policy: epsilon-greedy
            next_state, reward, done, _ = env.step(action)
            agent.update(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
        rewards.append(total_reward)
    return rewards
```

3.4 Comparing the Results
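Compared with Sarsa, Q-Learning shows:
- Path choice: learns the risky shortest route right along the cliff edge
- Convergence speed: usually reaches a high-reward greedy policy sooner than Sarsa
- Risk exposure: because it still explores ε-greedily during training, it occasionally falls off the cliff, so its per-episode training reward fluctuates and is often lower on average than Sarsa's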
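A rough calculation shows why ε-greedy exploration makes the edge route risky during training. Assumptions: ε = 0.1, 4 actions, the agent follows the edge route, and detours caused by other exploratory moves are ignored:

```python
# Rough estimate under the stated assumptions (epsilon-greedy over 4 actions,
# epsilon = 0.1, agent follows the cliff-edge route, detours ignored)
epsilon = 0.1
n_actions = 4
p_fall_per_step = epsilon / n_actions  # the random action happens to point into the cliff
risky_steps = 11                       # start cell (right) + 10 cells above the cliff (down)
p_safe_pass = (1 - p_fall_per_step) ** risky_steps
print(f"P(fall per risky step) = {p_fall_per_step:.3f}")  # 0.025
print(f"P(at least one fall)   = {1 - p_safe_pass:.3f}")  # 0.243
```

Under these assumptions roughly one edge-route episode in four hits the cliff at least once. Sarsa's targets include those falls, so it learns lower values for cliff-adjacent states; Q-Learning's max-based targets do not, so its greedy policy keeps the edge route.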
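```python
# Compare the moving-average rewards of the two algorithms
# (sarsa_ma / qlearn_ma: moving_average() applied to each algorithm's reward list)
plt.plot(sarsa_ma, label='Sarsa')
plt.plot(qlearn_ma, label='Q-Learning')
plt.xlabel('Episode')
plt.ylabel('Total Reward (moving average)')
plt.legend()
plt.show()
```

4. In-Depth Analysis: Algorithm Differences and Engineering Practice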
4.1 Why the Policies Differ
A heatmap of state values makes the difference between the two policies easy to see:
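| State type | Sarsa | Q-Learning |
|---|---|---|
| States next to the cliff | Lower value, avoided | Higher value, risk accepted |
| States on the safe path | Evenly spread value gradient | Steep value gradient |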
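```python
def plot_values(agent, title):
    # State value V(s) = max_a Q(s, a), reshaped to the 4x12 grid (row-major states)
    values = np.max(agent.Q, axis=1).reshape(4, 12)
    plt.imshow(values, cmap='hot')
    plt.colorbar()
    plt.title(title)
    plt.show()
```

4.2 Hyperparameter Tuning Guide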
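Experimental observations on the key hyperparameters:
| Parameter | Typical range | Effect on Sarsa | Effect on Q-Learning |
|---|---|---|---|
| Learning rate α | 0.01-0.5 | Too large causes oscillation | Tolerates larger values (e.g. 0.5) |
| Exploration rate ε | 0.05-0.3 | Needs sustained exploration | Can be decayed over time |
| Discount factor γ | 0.8-0.99 | Higher values (0.95) work better | Moderate values (0.9) work best |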
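The "too large causes oscillation" entry can be illustrated without the environment. This toy sketch tracks a noisy target with a constant step size; the target distribution is an assumption, loosely mimicking an occasional cliff fall. For independent targets, the estimate's stationary variance is α/(2-α) times the target's variance, so α = 0.5 fluctuates far more than α = 0.05. It illustrates the mechanism only and does not reproduce the table's measurements:

```python
import numpy as np

def track_estimate(alpha, n_steps=2000, seed=0):
    """Constant-step-size estimate q <- q + alpha * (target - q) of a noisy target.

    Toy target (assumption): -1 with probability 0.9, -100 with probability 0.1,
    so its true mean is -10.9.
    """
    rng = np.random.default_rng(seed)
    q = 0.0
    history = []
    for _ in range(n_steps):
        target = -100.0 if rng.random() < 0.1 else -1.0
        q += alpha * (target - q)
        history.append(q)
    return np.array(history)

for alpha in (0.05, 0.5):
    tail = track_estimate(alpha)[-1000:]  # drop the warm-up phase
    print(f"alpha={alpha}: mean={tail.mean():.2f}, std={tail.std():.2f}")
```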
4.3 Practical Tips and Pitfalls
Tips:
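- Decay ε for Q-Learning: `epsilon = max(0.01, epsilon * 0.995)`
- Initialize Q-values optimistically to encourage exploration (e.g. 0, which is optimistic here because every true return is negative)
- Monitor how much the Q-values change to judge convergence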
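The article uses two different ε schedules: the multiplicative decay in the tip above and the step schedule in `DecayEpsilonAgent` below. This sketch (the function names are hypothetical) generates both so they can be compared; with ε₀ = 0.1, the multiplicative version reaches the 0.01 floor after 460 episodes:

```python
def multiplicative_decay(epsilon0=0.1, rate=0.995, floor=0.01, n_episodes=1000):
    """epsilon <- max(floor, epsilon * rate), applied once per episode."""
    eps, schedule = epsilon0, []
    for _ in range(n_episodes):
        schedule.append(eps)
        eps = max(floor, eps * rate)
    return schedule

def step_decay(epsilon0=0.1, n_episodes=1000):
    """epsilon = epsilon0 / (1 + episode // 100): drops every 100 episodes."""
    return [epsilon0 / (1 + ep // 100) for ep in range(n_episodes)]

mult = multiplicative_decay()
step = step_decay()
first_floor = next(i for i, e in enumerate(mult) if e == 0.01)
print(first_floor)                  # 460
print(step[0], step[100], step[999])  # 0.1 0.05 0.01
```

The multiplicative schedule decays smoothly and stops at a fixed floor, while the step schedule keeps shrinking in discrete jumps, so the two behave quite differently late in training.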
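Common pitfalls:
- A fixed ε keeps Q-Learning falling off the cliff during training
- An overly large α prevents Sarsa from converging stably
- Never periodically evaluating how the greedy policy actually performs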
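For the last pitfall, an exploration-free rollout can be run every few episodes alongside training. A self-contained sketch, assuming the same grid, action order and reward convention as this article (the helper `evaluate_greedy` and its `max_steps` guard are illustrative additions):

```python
import numpy as np

SHAPE, START, GOAL = (4, 12), (3, 0), (3, 11)
CLIFF = {(3, c) for c in range(1, 11)}
DELTAS = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # up, right, down, left (article's order)

def evaluate_greedy(Q, max_steps=100):
    """Roll out argmax_a Q(s, a) with no exploration; return (total_reward, reached_goal).

    Assumed rewards: -1 per step, -100 and back to the start for the cliff.
    max_steps guards against greedy policies that loop forever.
    """
    r, c = START
    total = 0
    for _ in range(max_steps):
        a = int(np.argmax(Q[r * SHAPE[1] + c]))
        dr, dc = DELTAS[a]
        r = min(max(r + dr, 0), SHAPE[0] - 1)
        c = min(max(c + dc, 0), SHAPE[1] - 1)
        if (r, c) in CLIFF:
            total -= 100
            r, c = START
        else:
            total -= 1
        if (r, c) == GOAL:
            return total, True
    return total, False
```

Inside a training loop, something like `if episode % 50 == 0: print(episode, evaluate_greedy(agent.Q))` reports the exploration-free return. An all-zero Q table yields `(-100, False)`: `np.argmax` picks action 0 (up) on ties, so the agent stalls in the top-left corner.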
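```python
# epsilon-decay example
class DecayEpsilonAgent(QLearningAgent):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.initial_epsilon = self.epsilon

    def choose_action(self, state, episode=None):
        # epsilon = initial / (1 + episode // 100): halved after 100 episodes,
        # one third after 200, and so on. The training loop must pass the episode
        # index; without it the current epsilon is kept.
        if episode is not None:
            self.epsilon = self.initial_epsilon / (1 + episode // 100)
        return super().choose_action(state)
```

5. Advanced Extensions: Algorithm Variants and Performance Improvements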
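5.1 Expected Sarsa
A variant that combines the strengths of Sarsa and Q-Learning: the target uses the expected Q-value in s' under the ε-greedy policy, instead of a sampled a' (Sarsa) or the maximum (Q-Learning):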
```python
class ExpectedSarsaAgent(QLearningAgent):
    def update(self, state, action, reward, next_state, done):
        if done:
            td_target = reward
        else:
            # epsilon-greedy action probabilities in next_state:
            # epsilon / |A| for every action, plus (1 - epsilon) on the greedy one
            n_actions = self.env.action_space.n
            policy = np.ones(n_actions) * self.epsilon / n_actions
            policy[np.argmax(self.Q[next_state])] += 1 - self.epsilon
            # Expected value of Q(s', .) under that policy
            td_target = reward + self.gamma * np.sum(policy * self.Q[next_state])
        self.Q[state][action] += self.alpha * (td_target - self.Q[state][action])
```

5.2 Experience Replay
An improvement that raises sample efficiency:
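```python
import collections
import random

class ReplayBuffer:
    def __init__(self, capacity=1000):
        # The oldest transitions are dropped automatically once capacity is reached
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

# Inside the training loop (replaying old transitions is off-policy,
# so use it with an agent such as QLearningAgent, not plain Sarsa)
buffer = ReplayBuffer()
for episode in range(episodes):
    # ... environment interaction ...
    buffer.add((state, action, reward, next_state, done))
    # Sample a batch for updates once the buffer holds enough transitions
    if len(buffer) >= 32:
        batch = buffer.sample(32)
        for exp in batch:
            agent.update(*exp)
```

5.3 Multi-Step TD Learning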
A compromise between Monte Carlo and one-step TD methods:
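```python
class NStepSarsaAgent(SarsaAgent):
    def __init__(self, env, n_steps=3, **kwargs):
        # env comes first so that NStepSarsaAgent(env) works as expected
        super().__init__(env, **kwargs)
        self.n_steps = n_steps
        self.trajectory = []  # pending (state, action, reward) tuples

    def _update_oldest(self, bootstrap):
        # G = r_t + gamma*r_{t+1} + ... + gamma^(k-1)*r_{t+k-1} + gamma^k * bootstrap
        rewards = [r for _, _, r in self.trajectory]
        G = sum(r * self.gamma**i for i, r in enumerate(rewards))
        G += self.gamma**len(rewards) * bootstrap
        s, a, _ = self.trajectory.pop(0)
        self.Q[s][a] += self.alpha * (G - self.Q[s][a])

    def update(self, state, action, reward, next_state, next_action, done):
        # next_action must be the action the agent will actually take in next_state
        self.trajectory.append((state, action, reward))
        if done:
            # Episode over: flush every pending state with no bootstrap term
            while self.trajectory:
                self._update_oldest(0.0)
        elif len(self.trajectory) >= self.n_steps:
            self._update_oldest(self.Q[next_state][next_action])
```

Conclusion: From Cliff Walking to Real-World Applications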
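This seemingly simple grid world has already taken us through the core ideas of reinforcement learning. With suitable adjustments, these algorithms can be applied in real projects to:
- Robot path planning
- Game AI strategy optimization
- Resource scheduling and decision systems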
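Remember that no algorithm is perfect for every problem. Sarsa's conservative robustness and Q-Learning's aggressive efficiency each have their place. Experienced practitioners choose the tool that fits the problem at hand and find the best hyperparameter combination through systematic experiments.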
