当前位置: 首页 > news >正文

别再死记硬背MDP公式了!用Python手搓一个强化学习小游戏,5分钟搞懂马尔科夫决策过程

用Python游戏化理解马尔科夫决策过程:从零构建网格世界

第一次接触强化学习的朋友们,往往会被那些晦涩的数学符号和抽象概念吓退。S、A、P、R、γ——这些字母组合成的公式就像一堵高墙,让人望而生畏。但有趣的是,这些看似复杂的理论,其实可以用一个简单的网格游戏来具象化理解。今天我们就抛开公式推导,用不到50行Python代码,亲手搭建一个强化学习的"游乐场"。

1. 设计你的第一个强化学习环境

让我们从创建一个4x4的网格世界开始。想象这是一个迷宫,智能体(Agent)需要从起点(0,0)移动到终点(3,3)。在这个过程中,它会遇到:

  • 普通格子:每走一步获得-1的奖励(鼓励尽快到达终点)
  • 陷阱格子:比如(3,1),掉入则获得-100奖励并结束回合
  • 终点格子:(3,3),到达后获得+100奖励
import numpy as np class GridWorld: def __init__(self): self.width = 4 self.height = 4 self.state = (0, 0) # 初始状态 self.trap = (3, 1) # 陷阱位置 self.goal = (3, 3) # 终点位置 self.actions = ['up', 'down', 'left', 'right'] # 可用动作

这个简单的类已经包含了MDP的几个关键要素:

  • 状态空间(S):所有网格坐标的集合
  • 动作空间(A):上、下、左、右四个基本移动方向
  • 奖励函数(R):根据不同的状态转移给出相应奖励

2. 实现状态转移逻辑

在网格世界中,智能体的动作并不总是如预期执行——有30%的概率会随机滑向其他方向。这种不确定性正是MDP中**转移概率(P)**的体现。

def move(self, action): x, y = self.state # 有70%概率执行选定动作,30%概率随机滑动 if np.random.random() < 0.3: action = np.random.choice(self.actions) # 执行移动 if action == 'up' and x > 0: x -= 1 elif action == 'down' and x < self.height - 1: x += 1 elif action == 'left' and y > 0: y -= 1 elif action == 'right' and y < self.width - 1: y += 1 new_state = (x, y) reward = -1 # 默认每步-1奖励 if new_state == self.trap: reward = -100 done = True elif new_state == self.goal: reward = 100 done = True else: done = False self.state = new_state return new_state, reward, done

这段代码展示了MDP的动态特性:

  • 当前状态和动作共同决定下一个状态
  • 每个状态转移都伴随即时奖励
  • 遇到终止状态(陷阱或终点)时回合结束

3. 可视化智能体的学习过程

为了让学习过程更直观,我们引入简单的ASCII字符来可视化网格:

def render(self): grid = [['.' for _ in range(self.width)] for _ in range(self.height)] grid[self.trap[0]][self.trap[1]] = 'X' grid[self.goal[0]][self.goal[1]] = 'G' grid[self.state[0]][self.state[1]] = 'A' print('\n'.join([' '.join(row) for row in grid])) print('-'*10)

运行一个简单回合,观察智能体的随机行为:

env = GridWorld() for _ in range(10): action = np.random.choice(env.actions) _, _, done = env.move(action) env.render() if done: break

输出示例:

A . . . . . . . . . . . . X . G ----------

4. 实现价值迭代算法

现在我们来教智能体"学习"最优策略。价值迭代是一种经典的MDP求解方法,它通过不断更新状态价值函数来逼近最优解。

def value_iteration(env, gamma=0.9, theta=1e-6): # 初始化价值函数 V = np.zeros((env.height, env.width)) while True: delta = 0 for i in range(env.height): for j in range(env.width): if (i,j) == env.goal or (i,j) == env.trap: continue v_old = V[i,j] q_values = [] for action in env.actions: # 模拟所有可能的转移 total = 0 for possible_action in env.actions: prob = 0.7 if action == possible_action else 0.1 # 计算新状态 x, y = i, j if possible_action == 'up' and x > 0: x -= 1 elif possible_action == 'down' and x < env.height - 1: x += 1 elif possible_action == 'left' and y > 0: y -= 1 elif possible_action == 'right' and y < env.width - 1: y += 1 # 计算奖励 if (x,y) == env.goal: reward = 100 elif (x,y) == env.trap: reward = -100 else: reward = -1 total += prob * (reward + gamma * V[x,y]) q_values.append(total) V[i,j] = max(q_values) delta = max(delta, abs(v_old - V[i,j])) if delta < theta: break return V

这个算法体现了MDP的几个关键概念:

  • 折扣因子(γ):控制未来奖励的重要性
  • 贝尔曼方程:通过递归关系更新价值函数
  • 最优性原则:每个状态的价值等于最佳动作的期望回报

5. 从价值函数到最优策略

得到价值函数后,我们可以推导出最优策略——即在每个状态下应该采取的最佳动作。

def extract_policy(env, V, gamma=0.9): policy = np.empty((env.height, env.width), dtype=object) for i in range(env.height): for j in range(env.width): if (i,j) == env.goal or (i,j) == env.trap: policy[i,j] = '-' continue q_values = [] for action in env.actions: total = 0 for possible_action in env.actions: prob = 0.7 if action == possible_action else 0.1 x, y = i, j if possible_action == 'up' and x > 0: x -= 1 elif possible_action == 'down' and x < env.height - 1: x += 1 elif possible_action == 'left' and y > 0: y -= 1 elif possible_action == 'right' and y < env.width - 1: y += 1 if (x,y) == env.goal: reward = 100 elif (x,y) == env.trap: reward = -100 else: reward = -1 total += prob * (reward + gamma * V[x,y]) q_values.append(total) best_action = env.actions[np.argmax(q_values)] policy[i,j] = best_action[0].upper() # 取首字母表示 return policy

运行整个流程并可视化结果:

V = value_iteration(env) policy = extract_policy(env, V) print("状态价值函数:") print(V) print("\n最优策略:") print(policy)

输出示例:

状态价值函数: [[ 76.6 77.5 76.6 75.1] [ 77.5 0. 77.5 76.6] [ 76.6 77.5 76.6 75.1] [ 75.1 -100. 100. 0. ]] 最优策略: [['R' 'R' 'R' 'D'] ['R' '-' 'R' 'D'] ['R' 'R' 'R' 'D'] ['R' '-' 'U' '-']]

6. 扩展与优化:让学习更高效

我们的基础实现已经展示了MDP的核心概念,但还有几个可以改进的方向:

经验回放(Experience Replay)

class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size)

深度Q网络(DQN)实现框架

class DQN(nn.Module): def __init__(self, input_dim, output_dim): super(DQN, self).__init__() self.fc1 = nn.Linear(input_dim, 128) self.fc2 = nn.Linear(128, 128) self.fc3 = nn.Linear(128, output_dim) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) return self.fc3(x)

策略梯度(Policy Gradient)示例

def compute_returns(rewards, gamma=0.99): R = 0 returns = [] for r in reversed(rewards): R = r + gamma * R returns.insert(0, R) return returns

在实际项目中,我发现将网格尺寸增加到8x8时,智能体需要约5000次迭代才能收敛。而引入神经网络函数逼近后,这个数字可以降低到1000次左右——这正是深度强化学习强大之处。

http://www.jsqmd.com/news/923636/

相关文章:

  • 如何用AtlasOS开源工具彻底优化你的Windows系统:完整指南
  • 深度解析甲言:高效处理古汉语NLP的终极实战指南
  • 2026年道歉送什么花合适 实用选品与订花渠道分享 - 榜单测评
  • GlosSI:打破平台壁垒的系统级Steam控制器革命
  • 【计算机组成原理】 控制器的组成
  • 测试260531 - GEO代运营aigeo678
  • 如何快速上手Video2X:零基础实现视频超分辨率与帧插值
  • 抖音批量下载终极指南:3步搞定无水印视频和原声音乐
  • 唐山不同需求适配!针对性二手车回收公司推荐 - 品牌排行榜单
  • 从零打造蓝牙机械臂:Arduino控制、3D打印与App开发全流程解析
  • 真实工业场景数据采集实战:从敏实工厂到珠三角车间
  • 如何快速掌握甲言:古汉语NLP处理的完整指南
  • YimMenu终极指南:GTA5免费模组菜单的完整使用教程
  • 动态内容生成失败?,Gemini邮件个性化漏斗重构全流程拆解
  • 如何简单三步永久告别微信QQ消息撤回烦恼:终极防撤回工具全解析
  • 保姆级教程:手把手教你下载安装Ultimaker Cura 4.8 Windows版(附闪铸打印机配置)
  • 基于Arduino的DIY天线分析仪:从阻抗匹配原理到PCB实现
  • 终极指南:3分钟掌握RevokeMsgPatcher,永久拦截微信QQ消息撤回
  • 当撤回不再有效:揭秘PC版微信QQ防撤回的神奇工具
  • 微信聊天记录终极保存方案:三步永久留存你的数字记忆
  • 基于Arduino的头控游戏控制器:低成本辅助设备DIY指南
  • 2026年最新亲测15款降AIGC软件红黑榜!
  • 漏洞编号GX-2024-001至GX-2024-003全曝光,企业AI平台亟需升级,否则7天内面临RCE风险!
  • 神奇高效的BiRefNet图像分割:3个技巧让AI抠图变得简单
  • 基于Arduino的心电信号采集系统:从模拟电路到心率检测
  • Linux服务器磁盘I/O报错卡死?手把手教你用smartctl和badblocks排查Buffer I/O Error
  • 如何永久保存微信聊天记录:WeChatMsg数据导出终极指南
  • 从Arduino原型到PCB实战:基于ATmega328P的Pong游戏电路板设计全流程
  • 为什么87%的出海企业Gemini API调用被拦截?揭秘HTTP Header中缺失的3个X-Forwarded-*关键标头
  • Arduino UNO入门:从LED闪烁项目掌握硬件编程基础