当前位置: 首页 > news >正文

用Python玩转强化学习:从‘赌徒问题’实战理解MDP的策略迭代与价值迭代

用Python玩转强化学习:从‘赌徒问题’实战理解MDP的策略迭代与价值迭代

在强化学习的入门阶段,许多开发者都会被马尔科夫决策过程(MDP)的理论公式所困扰。本文将通过一个经典的"赌徒问题",用Python代码带你直观理解策略迭代(Policy Iteration)和价值迭代(Value Iteration)这两种核心算法。我们将从零开始构建完整的解决方案,并通过可视化对比两种方法的差异。

1. 问题建模与环境搭建

赌徒问题的核心是:一个赌徒每次可以选择下注金额,有p的概率赢得赌注,(1-p)的概率输掉赌注。游戏在赌徒达到100美元或破产时结束。我们需要找到在不同资金状态下最优的下注策略。

首先建立问题的基础框架:

import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm class GamblerEnv: def __init__(self, goal=100, p_h=0.4): self.goal = goal self.p_h = p_h # 硬币正面朝上的概率 self.states = np.arange(goal + 1) # 0到100的所有状态 def get_actions(self, s): """获取当前状态下的可用动作""" return np.arange(1, min(s, self.goal - s) + 1) def step(self, s, a): """执行动作并返回新状态和奖励""" if np.random.random() < self.p_h: s_new = s + a # 赢钱 else: s_new = s - a # 输钱 reward = 1 if s_new == self.goal else 0 done = (s_new == self.goal) or (s_new == 0) return s_new, reward, done

这个环境类封装了赌徒问题的核心逻辑。get_actions方法返回当前状态下允许的下注金额,step方法模拟游戏的一步进展。

2. 策略迭代算法实现

策略迭代分为两个交替进行的阶段:策略评估和策略改进。我们先来看完整的Python实现:

class PolicyIteration: def __init__(self, env, gamma=1.0, theta=1e-8): self.env = env self.gamma = gamma # 折扣因子 self.theta = theta # 收敛阈值 self.V = np.zeros(env.goal + 1) # 值函数初始化 self.V[env.goal] = 1.0 # 达到目标的价值为1 self.policy = np.zeros(env.goal + 1, dtype=int) # 策略初始化 self.history = [] # 记录迭代过程 def policy_evaluation(self): """评估当前策略下的值函数""" while True: delta = 0 old_V = self.V.copy() self.history.append(old_V) for s in self.env.states[1:self.env.goal]: a = self.policy[s] actions = self.env.get_actions(s) if a not in actions: # 确保策略有效 a = actions[0] self.policy[s] = a # 计算期望价值 v = 0 s_win = s + a s_lose = s - a v += self.env.p_h * (0 + self.gamma * old_V[s_win]) v += (1 - self.env.p_h) * (0 + self.gamma * old_V[s_lose]) self.V[s] = v delta = max(delta, abs(v - old_V[s])) if delta < self.theta: break def policy_improvement(self): """基于当前值函数改进策略""" policy_stable = True for s in self.env.states[1:self.env.goal]: old_a = self.policy[s] actions = self.env.get_actions(s) action_values = [] for a in actions: s_win = s + a s_lose = s - a q = self.env.p_h * (0 + self.gamma * self.V[s_win]) q += (1 - self.env.p_h) * (0 + self.gamma * self.V[s_lose]) action_values.append(q) best_a = actions[np.argmax(action_values)] self.policy[s] = best_a if old_a != best_a: policy_stable = False return policy_stable def solve(self, max_iter=100): """执行策略迭代""" for i in range(max_iter): self.policy_evaluation() if self.policy_improvement(): print(f"策略在第{i+1}次迭代后收敛") break

关键实现细节:

  1. policy_evaluation使用迭代法计算当前策略下的状态价值
  2. policy_improvement基于更新后的价值函数选择最优动作
  3. 每次迭代都完整记录值函数变化,便于后续分析

3. 价值迭代算法实现

价值迭代将策略评估和改进合并为一个步骤,直接寻找最优价值函数:

class ValueIteration: def __init__(self, env, gamma=1.0, theta=1e-8): self.env = env self.gamma = gamma self.theta = theta self.V = np.zeros(env.goal + 1) self.V[env.goal] = 1.0 self.policy = np.zeros(env.goal + 1, dtype=int) self.history = [] def value_iteration(self): """执行价值迭代""" while True: delta = 0 old_V = self.V.copy() self.history.append(old_V) for s in self.env.states[1:self.env.goal]: actions = self.env.get_actions(s) max_v = -np.inf for a in actions: s_win = s + a s_lose = s - a v = self.env.p_h * (0 + self.gamma * old_V[s_win]) v += (1 - self.env.p_h) * (0 + self.gamma * old_V[s_lose]) if v > max_v: max_v = v best_a = a self.V[s] = max_v self.policy[s] = best_a delta = max(delta, abs(max_v - old_V[s])) if delta < self.theta: break def solve(self): """执行求解过程""" self.value_iteration()

价值迭代的核心特点是:

  • 每次迭代直接选择最大价值更新
  • 不需要显式的策略评估阶段
  • 最终策略直接从最优价值函数导出

4. 实验分析与可视化对比

现在让我们运行两种算法并比较它们的结果:

# 初始化环境 env = GamblerEnv(goal=100, p_h=0.4) # 策略迭代 pi = PolicyIteration(env) pi.solve() # 价值迭代 vi = ValueIteration(env) vi.solve() # 可视化值函数收敛过程 plt.figure(figsize=(12, 6)) for i, v in enumerate(pi.history[::5]): plt.plot(v[1:100], alpha=0.3, color='blue', label='Policy Iteration' if i == 0 else "") for i, v in enumerate(vi.history[::5]): plt.plot(v[1:100], alpha=0.3, color='red', label='Value Iteration' if i == 0 else "") plt.title("Value Function Convergence") plt.xlabel("Capital") plt.ylabel("Value Estimate") plt.legend() plt.grid(True) plt.show() # 比较最优策略 plt.figure(figsize=(12, 4)) plt.subplot(1, 2, 1) plt.bar(env.states[1:100], pi.policy[1:100]) plt.title("Policy Iteration - Optimal Strategy") plt.xlabel("Capital") plt.ylabel("Optimal Bet") plt.subplot(1, 2, 2) plt.bar(env.states[1:100], vi.policy[1:100]) plt.title("Value Iteration - Optimal Strategy") plt.xlabel("Capital") plt.ylabel("Optimal Bet") plt.tight_layout() plt.show()

实验结果会显示几个关键发现:

  1. 收敛速度:价值迭代通常收敛更快,因为它每次迭代都直接追求最优
  2. 策略差异:两种方法得到的策略在p=0.4时会有显著不同
  3. 值函数形状:两种方法最终的值函数非常接近,但收敛路径不同

5. 实战技巧与常见问题

在实现MDP算法时,有几个关键点需要注意:

收敛条件设置

  • θ值不宜过小,否则可能导致不必要的迭代
  • 最大迭代次数应设置合理安全值,防止无限循环
# 好的实践示例 theta = 1e-6 # 对于这个问题足够精确 max_iter = 1000 # 安全上限

浮点数精度问题

  • 避免直接比较浮点数,使用np.isclose
  • 适当控制计算精度
# 不推荐 if v1 == v2: ... # 推荐 if np.isclose(v1, v2, atol=1e-8): ...

调试建议

  1. 从小规模问题开始(如goal=10)
  2. 打印中间结果验证计算逻辑
  3. 可视化每次迭代的值函数变化
# 调试打印示例 print(f"Iter {i}: delta={delta:.4f}, V[50]={self.V[50]:.4f}")

性能优化技巧

  • 使用向量化操作代替循环
  • 对重复计算进行缓存
  • 利用稀疏性减少计算量
# 向量化改进示例 actions = np.arange(1, min(s, goal-s)+1) win_values = self.V[s + actions] * self.p_h lose_values = self.V[s - actions] * (1 - self.p_h) action_values = win_values + lose_values

在实际项目中,我发现价值迭代通常更适合计算资源有限的情况,而策略迭代在策略变化剧烈时可能更稳定。当p_h=0.4时,两种方法产生的策略差异尤其值得分析——这反映了在不利条件下寻找稳健策略的复杂性。

http://www.jsqmd.com/news/906834/

相关文章:

  • 别再被Finder骗了!Mac里多出来的那个‘Macintosh HD’到底是什么?APFS卷组与firmlink机制全解析
  • 保姆级教程:在Ubuntu Server 22.04上搞定图形桌面和VNC远程连接(含RealVNC账号注册避坑)
  • 3D打印热床附着力与高温PI胶带应用技术指南
  • 别再只盯着TXOUTCLK了!手把手教你用FPGA的RXOUTCLK(线路恢复时钟)驱动RXUSRCLK
  • 深入UGUI底层:手把手教你用OnPopulateMesh和顶点偏移,实现Image的任意2D变形
  • 一文读懂AI人工智能:从概念到范式,小白也能秒懂
  • Keil µVision编译错误信息缺失的McAfee杀毒软件解决方案
  • 避坑指南:macOS重装/降级时,磁盘工具抹掉选项怎么选?APFS还是Mac OS扩展?
  • 别再乱改权限了!用微软官方AccessChk工具,5分钟排查Windows系统安全漏洞
  • 从‘平均主义’到‘精准加权’:手把手复现阿里DIN模型中的Attention Unit(附PyTorch代码)
  • 新型智慧城市 + 城市大数据应用完整解决方案(架构 + 平台建设 + 落地实践)
  • pdfClaw免登录在线PDF转Word
  • 从‘克莱因四元群’到‘复数旋转’:手把手带你验证两个群是否同构(附Python代码)
  • 鼎讯信通 RM‑1000 高性能无线电综合测试仪:铁路通信电台检测优选
  • 丰城高端全屋定制商家如何选择?
  • 靠谱的门窗安装品牌企业
  • 基于Arduino与MAX7219的复古LED点阵时钟DIY:从硬件选型到外壳制作
  • 别再手动改乱码了!用convmv命令5分钟搞定Linux下整个文件夹的编码转换
  • 家常饮用养生酒,六味地黄酒暖心相伴
  • Linux系统通过stty命令修改串口波特率
  • AI发现潜伏18年的NGINX高危漏洞:CVE-2026-42945完整技术分析
  • Qt 5.7+ 虚拟键盘插件安装与配置全攻略(含Linux/Windows避坑指南)
  • 量子电路模拟:TDVP方法原理与实践优化
  • 2026公考机构深度横评:粉笔、华图、中公哪家强?
  • 免费.brd文件查看器终极指南:OpenBoardView让电路板设计查看如此简单
  • 保姆级教程:在Ubuntu 22.04上挂载VMFS6数据存储,轻松读取ESXi虚拟机文件
  • 从PR调色到Unity渲染:用Post Processing的Color Grading模块打造电影感游戏画面
  • 用Python和YOLOv5给摄像头装上‘尺子’:一个杯子引发的单目测距实战
  • 微波定向耦合器:原理、指标、架构与设计实例
  • 保姆级教程:在Ubuntu 20.04上从源码编译运行Cartographer ROS(含常见错误排查)