当前位置: 首页 > news >正文

别再死记硬背Sarsa公式了!用Python手搓一个走迷宫AI,5分钟搞懂On-Policy和Q-learning的区别

用Python构建迷宫AI:5分钟可视化Sarsa与Q-learning的本质差异

在咖啡厅里,我常看到学生对着强化学习教材皱眉——那些充满希腊字母的公式和抽象概念,确实容易让人望而生畏。直到有一天,我让学生用代码实现了一个会走迷宫的AI,他们突然恍然大悟:"原来On-Policy和Off-Policy的区别这么直观!" 本文将带你复现这个魔法时刻:不需要死记硬背贝尔曼方程,而是通过编写一个会学习的迷宫探索者,亲眼见证Sarsa和Q-learning在策略选择上的根本差异。

1. 准备迷宫实验室

我们先搭建一个简单的网格世界。想象一个5×5的迷宫,其中(0,0)是起点,(4,4)是终点,某些格子是陷阱(奖励-1),终点有丰厚奖励(+10)。使用numpymatplotlib就能构建这个微型世界:

import numpy as np import matplotlib.pyplot as plt class MazeEnv: def __init__(self): self.size = 5 self.start = (0, 0) self.goal = (4, 4) self.obstacles = [(1, 1), (2, 3), (3, 1)] self.actions = ['up', 'down', 'left', 'right'] def step(self, state, action): x, y = state if action == 'up': x = max(0, x-1) elif action == 'down': x = min(self.size-1, x+1) elif action == 'left': y = max(0, y-1) elif action == 'right': y = min(self.size-1, y+1) new_state = (x, y) if new_state in self.obstacles: return state, -1, True # 撞墙回弹 if new_state == self.goal: return new_state, 10, True return new_state, -0.1, False # 每步小惩罚鼓励快速通关

关键设计细节

  • 每步给予-0.1的奖励,促使AI寻找最短路径
  • 障碍物碰撞会获得-1奖励并保持原地
  • 使用离散动作空间(上/下/左/右)简化问题

2. Sarsa算法实现

Sarsa作为On-Policy算法,其核心特点是"言行一致"——它用当前策略既选择动作也更新Q值。我们用一个SarsaAgent类来实现:

class SarsaAgent: def __init__(self, env, learning_rate=0.1, discount=0.9, epsilon=0.1): self.q_table = np.zeros((env.size, env.size, len(env.actions))) self.lr = learning_rate self.gamma = discount self.epsilon = epsilon self.actions = env.actions def choose_action(self, state, train=True): if train and np.random.random() < self.epsilon: return np.random.choice(len(self.actions)) return np.argmax(self.q_table[state]) def learn(self, state, action, reward, next_state, next_action): current_q = self.q_table[state][action] next_q = self.q_table[next_state][next_action] td_target = reward + self.gamma * next_q self.q_table[state][action] += self.lr * (td_target - current_q)

算法运行流程

  1. 在状态Sₜ根据ε-greedy策略选择动作Aₜ
  2. 执行动作后获得Rₜ₊₁和Sₜ₊₁
  3. 在Sₜ₊₁继续用相同策略选择Aₜ₊₁
  4. 用五元组(Sₜ, Aₜ, Rₜ₊₁, Sₜ₊₁, Aₜ₊₁)更新Q表

观察下面这个训练过程的可视化,你会发现Sarsa的路径往往更加保守:

Episode 1: S→→→↓→→→→G (碰撞2次) Episode 50: S→→↓→→→G Episode 100: S→↓→→G (稳定路径)

3. Q-learning实现对比

Q-learning作为Off-Policy算法,其更新规则允许"说一套做一套"。我们只需修改learn方法:

class QLearningAgent(SarsaAgent): def learn(self, state, action, reward, next_state, _): current_q = self.q_table[state][action] max_next_q = np.max(self.q_table[next_state]) # 关键区别! td_target = reward + self.gamma * max_next_q self.q_table[state][action] += self.lr * (td_target - current_q)

核心差异对比表

特性SarsaQ-learning
策略一致性On-Policy (言行一致)Off-Policy (目标策略≠行为策略)
更新公式使用实际执行的Aₜ₊₁使用max Q值的动作
探索风险会规避危险格子可能靠近危险
收敛性更稳定可能更激进
适用场景高风险环境(如机器人控制)游戏AI等可承受风险场景

4. 可视化对比训练过程

让我们用matplotlib创建动态对比图。以下代码展示两种算法在相同迷宫中的学习轨迹差异:

def plot_comparison(sarsa_paths, qlearn_paths): plt.figure(figsize=(12, 5)) # Sarsa路径绘制 plt.subplot(121) for path in sarsa_paths: plt.plot([p[1] for p in path], [p[0] for p in path], 'b-', alpha=0.1) plt.title("Sarsa (On-Policy) 路径") # Q-learning路径绘制 plt.subplot(122) for path in qlearn_paths: plt.plot([p[1] for p in path], [p[0] for p in path], 'r-', alpha=0.1) plt.title("Q-learning (Off-Policy) 路径")

典型现象观察:

  • Sarsa:早期会绕开障碍物,即使这意味着更长的路径
  • Q-learning:常出现"切角"行为,偶尔会碰到障碍物但最终学会最优路径

5. 高级话题:从表格方法到神经网络

当迷宫扩大到20×20时,Q表格将变得低效。这时可以引入神经网络作为函数逼近器:

import torch import torch.nn as nn class DQN(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, 64), nn.ReLU(), nn.Linear(64, output_dim) ) def forward(self, x): return self.net(x) # Sarsa与Q-learning的神经网络实现差异: # Sarsa需要采样下一个动作Aₜ₊₁,而Q-learning直接取max Q值

经验回放的影响

  • Q-learning可以自由使用历史经验
  • Sarsa若使用经验回放,需要确保采样的Aₜ₊₁与当前策略兼容
# 伪代码:Sarsa的经验回放特殊处理 for transition in replay_buffer.sample(): s, a, r, s_next, a_next = transition if policy_changed: # 需要重新采样a_next a_next = current_policy.select_action(s_next) agent.learn(s, a, r, s_next, a_next)

在项目实践中,我发现当环境随机性较低时(如我们的迷宫),即使Sarsa使用经验回放也能良好工作。但在股票交易等高随机性场景,这种近似可能导致问题。

http://www.jsqmd.com/news/913158/

相关文章:

  • 从喷头滴漏到AI节水37%:一个Lindy灌溉集群的30天自动化演进日记(含Prometheus监控看板+告警阈值SOP)
  • 2026年AI写作辅助平台深度评测:6款工具专业水准得分排名
  • 基于Arduino与超声波传感器的高尔夫自动喂球器设计与实现
  • 基于Arduino与BNO055陀螺仪的桌面绘图机器人:从传感器融合到G代码解析
  • 2026年圆盘式过滤器行业评测:核心性能横向对比 - 优质品牌商家
  • 别再傻傻分不清!用Python代码5分钟搞懂机器学习里的min和argmin
  • 用Python和SVM给水质‘看相’:手把手教你从200张水色图到水质分类模型
  • 从HDRi到游戏画面:手把手教你用Blender和Python预处理IBL环境贴图(含代码)
  • 工业防爆监控技术解析与山东区域选型实践
  • Windows开始菜单修复终极指南:三步恢复消失的磁贴
  • Codex 新增“宠物”功能:不只是可爱,而是一个轻量工作状态提醒器
  • 工具使用、代理和 Voyager 论文
  • 93、CAN FD数据链路层核心:帧结构对比与DLC编码革命
  • 别再被多重共线性坑了!用Python的sklearn手把手教你调岭回归的alpha参数
  • 2026年嵌丝道口板TOP5厂商盘点 品质与实力对比 - 优质品牌商家
  • 172 号卡哪个推荐码是官方一级?10000 置顶权限真实解析 - 172号卡
  • 用Python实战贾俊平《统计学》第八章:手把手教你用SciPy搞定假设检验课后题
  • Lindy自动化项目管理:从概念验证到规模化落地的7个关键决策节点(附20年踩坑清单)
  • 第T9周:猫狗识别2
  • 从电容充放电到MOSFET驱动:一个被忽视的‘能量视角’与热设计陷阱
  • 单细胞分析入门:用Scanpy的AnnData对象管理你的数据,从h5ad读写到基础过滤
  • C语言分支和循环总结
  • Harness 中的请求影子复制:用于离线分析
  • 2026年5月更新:浙江老爹鞋制造商业内推荐与趋势解析 - 2026年企业资讯
  • Claude技术债爆发前夜(2024Q2实测预警:87%企业已超临界阈值)
  • 我的Obsidian知识库,现在可以自动剪藏笔记到本地了
  • 【从零开始的JUC并发第四章】:JUC常用工具类
  • 新手也能跑通大模型,Hugging Face 环境配置与模型加载指南
  • 纯小白向|OpenClaw 本地环境搭建,一步一图教学
  • 5分钟掌握VideoDownloadHelper:你的网页视频下载救星