当前位置: 首页 > news >正文

用Python玩转赌徒问题:手把手教你实现MDP的两种经典算法(附完整代码)

用Python玩转赌徒问题:手把手教你实现MDP的两种经典算法(附完整代码)

马尔科夫决策过程(MDP)是强化学习的基础框架之一,而赌徒问题则是理解MDP的绝佳案例。本文将带你从零开始,用Python实现策略迭代和值迭代这两种经典算法,并通过可视化分析不同参数下的策略变化。无论你是想巩固理论知识,还是希望获得可复用的代码模板,这篇文章都能满足你的需求。

1. 环境准备与问题建模

在开始编码前,我们需要明确赌徒问题的数学模型。假设一个赌徒初始有s美元(1≤s≤99),每次可以选择下注1到min(s,100-s)美元。硬币正面朝上的概率为ph,获胜则获得下注金额,失败则失去下注金额。游戏在达到100美元或破产时结束。

首先安装必要的库:

pip install numpy matplotlib seaborn

定义问题参数:

GOAL = 100 # 目标金额 STATES = np.arange(GOAL + 1) # 所有可能状态(0到100) ph = 0.4 # 硬币正面概率 gamma = 1 # 折扣因子

状态值函数初始化时,只有达到目标状态(100)才有奖励1:

state_values = np.zeros(GOAL + 1) state_values[GOAL] = 1.0

2. 策略迭代算法实现

策略迭代分为两个交替进行的阶段:策略评估和策略改进。我们先来看完整的类实现:

class PolicyIteration: def __init__(self, goal=100, proba_h=0.4, theta=1e-9, gamma=1): self.ph = proba_h self.gamma = gamma self.goal = goal self.theta = theta self.states = np.arange(goal + 1) self.state_values = np.zeros(goal + 1) self.state_values[goal] = 1.0 self.policy = np.zeros(goal + 1) # 初始策略全0 self.sweeps_history = [] # 记录每次迭代的值函数 def policy_evaluation(self): while True: old_values = self.state_values.copy() self.sweeps_history.append(old_values) for s in self.states[1:self.goal]: actions = np.arange(min(s, self.goal - s)) + 1 action_returns = [] for a in actions: ret = self.ph * (self.gamma * self.state_values[s + a]) + \ (1 - self.ph) * (self.gamma * self.state_values[s - a]) action_returns.append(ret) # 使用当前策略选择动作 current_a = int(self.policy[s]) if current_a == 0 and s < self.goal: # 初始策略为0,需要处理 current_a = actions[0] self.policy[s] = current_a self.state_values[s] = action_returns[actions.tolist().index(current_a)] delta = np.abs(self.state_values - old_values).max() if delta <= self.theta: break def policy_improvement(self): policy_stable = True for s in self.states[1:self.goal]: old_a = self.policy[s] actions = np.arange(min(s, self.goal - s)) + 1 action_returns = [] for a in actions: ret = self.ph * (self.gamma * self.state_values[s + a]) + \ (1 - self.ph) * (self.gamma * self.state_values[s - a]) action_returns.append(ret) # 选择回报最大的动作 max_a = actions[np.argmax(np.round(action_returns, 5))] self.policy[s] = max_a if old_a != max_a: policy_stable = False return policy_stable def solve(self): while True: self.policy_evaluation() if self.policy_improvement(): break

关键点说明:

  • policy_evaluation通过迭代更新状态值函数,直到变化小于阈值theta
  • policy_improvement根据当前值函数选择最优动作
  • solve方法交替执行上述两个步骤直到策略稳定

3. 值迭代算法实现

值迭代将策略评估和策略改进合并为一个步骤,直接更新最优值函数:

class ValueIteration: def __init__(self, goal=100, proba_h=0.4, theta=1e-9, gamma=1): self.ph = proba_h self.gamma = gamma self.goal = goal self.theta = theta self.states = np.arange(goal + 1) self.state_values = np.zeros(goal + 1) self.state_values[goal] = 1.0 self.policy = np.zeros(goal + 1) self.sweeps_history = [] def value_iteration(self): while True: old_values = self.state_values.copy() self.sweeps_history.append(old_values) for s in self.states[1:self.goal]: actions = np.arange(min(s, self.goal - s)) + 1 action_returns = [] for a in actions: ret = self.ph * (self.gamma * self.state_values[s + a]) + \ (1 - self.ph) * (self.gamma * self.state_values[s - a]) action_returns.append(ret) # 直接取最大值作为新状态值 self.state_values[s] = np.max(action_returns) delta = np.abs(self.state_values - old_values).max() if delta <= self.theta: break def derive_policy(self): for s in self.states[1:self.goal]: actions = np.arange(min(s, self.goal - s)) + 1 action_returns = [] for a in actions: ret = self.ph * (self.gamma * self.state_values[s + a]) + \ (1 - self.ph) * (self.gamma * self.state_values[s - a]) action_returns.append(ret) # 选择最优动作 self.policy[s] = actions[np.argmax(np.round(action_returns, 5))] def solve(self): self.value_iteration() self.derive_policy()

与策略迭代的主要区别:

  • 每次直接更新为最优值(取max),而不是当前策略下的期望值
  • 值收敛后才一次性推导出策略

4. 结果分析与可视化

实现算法后,我们比较ph=0.4和ph=0.55两种情况下的策略差异:

def plot_results(ph, title): # 策略迭代 pi = PolicyIteration(proba_h=ph) pi.solve() # 值迭代 vi = ValueIteration(proba_h=ph) vi.solve() plt.figure(figsize=(12, 8)) # 绘制策略 plt.subplot(2, 2, 1) plt.step(pi.states, pi.policy, where='post') plt.title(f'Policy Iteration (ph={ph})') plt.xlabel('Capital') plt.ylabel('Optimal stake') plt.subplot(2, 2, 2) plt.step(vi.states, vi.policy, where='post') plt.title(f'Value Iteration (ph={ph})') plt.xlabel('Capital') plt.ylabel('Optimal stake') # 绘制值函数 plt.subplot(2, 2, 3) plt.plot(pi.states, pi.state_values) plt.title('State Values (PI)') plt.xlabel('Capital') plt.ylabel('Value estimate') plt.subplot(2, 2, 4) plt.plot(vi.states, vi.state_values) plt.title('State Values (VI)') plt.xlabel('Capital') plt.ylabel('Value estimate') plt.tight_layout() plt.show() plot_results(0.4, "ph=0.4") plot_results(0.55, "ph=0.55")

关键发现:

  1. 当ph=0.4(劣势赌局)时,两种算法都建议保守策略,只在特定资本时下注较大金额
  2. 当ph=0.55(优势赌局)时,最优策略变得更激进,建议更大胆的下注
  3. 值迭代收敛更快,但策略迭代的策略变化过程更平滑

5. 算法对比与工程实践

在实际应用中,两种算法各有优劣:

特性策略迭代值迭代
收敛速度较慢较快
每次迭代计算量较大较小
中间结果可用性每次迭代都有完整策略只有最终策略
实现复杂度较高较低
适合场景需要中间策略/策略变化平缓只需最终结果/快速原型开发

工程优化建议:

  1. 向量化计算:将内部循环改为矩阵运算
# 替代原来的for循环 returns = ph * values[s + actions] + (1 - ph) * values[s - actions]
  1. 并行化:使用多进程处理状态更新
  2. 早期终止:检测策略是否早停滞
  3. 日志记录:保存每次迭代变化用于调试

常见问题解决:

  • 振荡问题:适当减小学习率或增加theta值
  • 收敛慢:检查奖励设置和折扣因子
  • 内存不足:使用稀疏矩阵表示大状态空间
# 示例:带收敛诊断的改进版值迭代 def value_iteration_enhanced(max_iter=1000): for i in range(max_iter): old_values = values.copy() for s in states[1:GOAL]: # ... 更新逻辑 ... delta = np.abs(values - old_values).max() if delta < theta: print(f"Converged at iteration {i}") break elif i % 10 == 0: print(f"Iter {i}, delta={delta:.4f}")

通过这个完整的实现案例,我们不仅掌握了MDP两种基本算法的编程技巧,还深入理解了它们在策略形成上的差异。建议读者尝试修改参数(如ph、gamma)或奖励函数,观察策略如何随之变化,这是巩固MDP概念的最佳方式。

http://www.jsqmd.com/news/918224/

相关文章:

  • 5步解决虚拟机手柄识别难题:DS4Windows虚拟机配置终极指南
  • 基于ESP32的四足机器人:从逆运动学到AI视觉的完整实现
  • 告别ImageNet标注!用DINO+ViT在无标签数据上实现80%+准确率的保姆级复现教程
  • 2026芜湖奢侈品名牌包包名牌手表回收哪家无套路? - 鸿运名品
  • #三清侠# 最近发现一个超有安全感的“新侠客”[特殊字符]
  • Go语言微服务安全与可靠性最佳实践
  • SQLite Viewer终极指南:如何在浏览器中零安装查看和管理SQLite数据库
  • DWG 格式兼容转换的实战应用与价值落地
  • 电力系统潮流计算Python工程包,含VS解决方案与完整源码
  • YOLO训练翻车?可能是你的TXT标注文件‘回炉’没做好!手把手教你TXT转回Labelme JSON
  • 破解免漆木门行业痛点:四稳共赢方法论如何打造高口碑产品? - 资讯纵览
  • 大语言模型如何“认识”你:从原理到个人数字身份监控实践
  • 3DS自制软件管理终极指南:Universal-Updater一键安装与更新完整教程
  • ABB 011865-003 3/8NPT 内外丝 90° 黄铜弯头
  • 【硬件_USB2.0】一文讲透USB2.0硬件工作原理
  • 5/17(3)
  • 基于Arduino的RC遥控车与激光计时系统DIY全攻略
  • 5/16(2)
  • 别再为CFD-POST云图毛刺抓狂了!手把手教你排查Fluent后处理显示异常(附完整流程)
  • 如何5分钟完成Honey Select 2终极汉化去码补丁安装:完整新手指南
  • 2026 中央电教馆美术教育指导教师证书详解|职业前景、报考流程、官方报名渠道推荐、证书含金量等问题一站式解答 - 教育官方推荐官
  • 编写二手闲置精品甄选定价程序,根据成色市场行情,智能给出合理转卖价。
  • ChatGPT求职信优化实战手册(HR内部评分标准首次公开)
  • 换热器哪家强?2026专业换热器选购指南 - 资讯纵览
  • 颠覆性开源气象革命:如何用Swift构建零成本的全球天气API
  • MacOS 运维常用命令大全(超全速查表)
  • 3D视觉赋能新能源补能无人化:自动充电 / 换电 / 加氢场景技术落地解析
  • 基于OpenCV与Mediapipe的手势识别:实现石头剪刀布人机对战
  • 牛顿迭代算法及使用条件
  • Gemini隐私政策不是法律文件,而是信任协议——用可验证隐私(VP)框架重构起草逻辑(含零知识证明集成示例)