别再死记硬背公式了!用Python手把手实现Model-based强化学习(值迭代/策略迭代对比)
用Python实战Model-based强化学习:值迭代与策略迭代的代码级对比
在强化学习领域,理论推导和数学公式固然重要,但真正理解算法精髓往往需要亲手实现一遍。本文将以Python代码为载体,带你完整实现值迭代和策略迭代算法,并通过可视化对比揭示两种经典方法的本质差异。我们将使用Gym库构建测试环境,用NumPy实现核心逻辑,最终生成直观的策略网格图——整个过程就像在Jupyter Notebook中完成一次算法解剖实验。
1. 环境搭建与问题建模
我们先定义一个简单的网格世界作为测试环境。这个5x5的方格中,智能体需要从任意位置移动到右上角的目标位置(奖励+1),同时避开左下角的陷阱(奖励-1)。移动方向包括上、下、左、右四个基本动作,碰到边界则保持原地。
import numpy as np import matplotlib.pyplot as plt from matplotlib import colors class GridWorld: def __init__(self, size=5): self.size = size self.goal = (0, size-1) # 右上角为目标 self.trap = (size-1, 0) # 左下角为陷阱 self.actions = ['up', 'down', 'left', 'right'] self.action_effects = { 'up': (-1, 0), 'down': (1, 0), 'left': (0, -1), 'right': (0, 1) } def step(self, state, action): """执行动作返回(新状态, 奖励)""" if state == self.goal or state == self.trap: return state, 0 # 终止状态 effect = self.action_effects[action] new_state = (max(0, min(self.size-1, state[0] + effect[0])), max(0, min(self.size-1, state[1] + effect[1]))) reward = 1 if new_state == self.goal else -1 if new_state == self.trap else 0 return new_state, reward这个环境的状态转移是确定性的,符合Model-based强化学习的基本假设。我们设置折扣因子γ=0.9,作为未来奖励的衰减系数。
2. 值迭代算法实现
值迭代的核心思想是直接优化状态价值函数,通过贝尔曼最优方程不断更新价值估计,直到收敛后再提取最优策略。以下是完整实现:
def value_iteration(env, gamma=0.9, theta=1e-6): """值迭代算法实现""" V = np.zeros((env.size, env.size)) # 初始化价值函数 policy = np.empty((env.size, env.size), dtype=object) while True: delta = 0 for i in range(env.size): for j in range(env.size): state = (i, j) if state == env.goal or state == env.trap: continue # 计算每个动作的Q值 q_values = [] for action in env.actions: new_state, reward = env.step(state, action) q = reward + gamma * V[new_state] q_values.append(q) # 更新价值函数 new_value = max(q_values) delta = max(delta, abs(new_value - V[state])) V[state] = new_value if delta < theta: # 收敛判断 break # 从价值函数提取确定性策略 for i in range(env.size): for j in range(env.size): state = (i, j) if state == env.goal: policy[i,j] = 'goal' elif state == env.trap: policy[i,j] = 'trap' else: # 选择Q值最大的动作 q_values = [] for action in env.actions: new_state, reward = env.step(state, action) q = reward + gamma * V[new_state] q_values.append(q) best_action = env.actions[np.argmax(q_values)] policy[i,j] = best_action return V, policy关键实现细节:
- 使用双重循环遍历所有状态
- 对每个状态计算所有可能动作的Q值
- 取最大Q值更新当前状态价值
- 收敛后,选择每个状态下Q值最大的动作构成策略
注意:值迭代中价值函数的更新公式V(s)←maxₐ[r + γV(s')]直接体现了贝尔曼最优方程,这是与策略迭代的本质区别。
3. 策略迭代算法实现
策略迭代采用交替执行策略评估和策略提升的策略,直到策略收敛。下面是分步骤实现:
def policy_evaluation(env, policy, V, gamma=0.9, theta=1e-6): """策略评估:计算给定策略下的价值函数""" while True: delta = 0 for i in range(env.size): for j in range(env.size): state = (i, j) if state == env.goal or state == env.trap: continue action = policy[i,j] new_state, reward = env.step(state, action) new_value = reward + gamma * V[new_state] delta = max(delta, abs(new_value - V[state])) V[state] = new_value if delta < theta: break return V def policy_improvement(env, V, policy, gamma=0.9): """策略提升:基于当前价值函数改进策略""" policy_stable = True for i in range(env.size): for j in range(env.size): state = (i, j) if state == env.goal or state == env.trap: continue old_action = policy[i,j] # 计算所有动作的Q值 q_values = [] for action in env.actions: new_state, reward = env.step(state, action) q = reward + gamma * V[new_state] q_values.append(q) # 选择最优动作 best_action = env.actions[np.argmax(q_values)] policy[i,j] = best_action if old_action != best_action: policy_stable = False return policy, policy_stable def policy_iteration(env, gamma=0.9): """完整的策略迭代算法""" # 初始化随机策略 policy = np.random.choice(env.actions, size=(env.size, env.size)) policy[env.goal] = 'goal' policy[env.trap] = 'trap' V = np.zeros((env.size, env.size)) while True: V = policy_evaluation(env, policy, V, gamma) policy, stable = policy_improvement(env, V, policy, gamma) if stable: break return V, policy策略迭代的特点:
- 策略评估阶段需要完全收敛到当前策略的价值函数
- 策略提升阶段采用贪心策略选择动作
- 整个过程交替进行直到策略不再变化
4. 算法对比与可视化分析
现在我们将两种算法运行在同一环境中,对比它们的收敛过程和最终策略:
# 运行两种算法 env = GridWorld() gamma = 0.9 V_vi, policy_vi = value_iteration(env, gamma) V_pi, policy_pi = policy_iteration(env, gamma) # 可视化价值函数 def plot_values(V, title): fig, ax = plt.subplots() im = ax.imshow(V, cmap='coolwarm') for i in range(V.shape[0]): for j in range(V.shape[1]): text = ax.text(j, i, f"{V[i,j]:.1f}", ha="center", va="center", color="black") ax.set_title(title) fig.tight_layout() plt.show() plot_values(V_vi, "Value Iteration - State Values") plot_values(V_pi, "Policy Iteration - State Values") # 可视化策略 def plot_policy(policy, title): action_map = {'up': '↑', 'down': '↓', 'left': '←', 'right': '→', 'goal': 'G', 'trap': 'X'} policy_symbols = np.vectorize(action_map.get)(policy) fig, ax = plt.subplots() ax.imshow(np.zeros_like(policy, dtype=int), cmap='gray') for i in range(policy.shape[0]): for j in range(policy.shape[1]): text = ax.text(j, i, policy_symbols[i,j], ha="center", va="center", color="red", fontsize=16, fontweight='bold') ax.set_title(title) plt.show() plot_policy(policy_vi, "Value Iteration - Optimal Policy") plot_policy(policy_pi, "Policy Iteration - Optimal Policy")通过可视化输出,我们可以观察到:
| 对比维度 | 值迭代 | 策略迭代 |
|---|---|---|
| 收敛速度 | 较快(直接优化价值函数) | 较慢(需策略评估完全收敛) |
| 计算复杂度 | 每次迭代计算量较大 | 策略评估阶段计算量较大 |
| 中间过程 | 只有价值函数更新 | 显式维护和更新策略 |
| 适用场景 | 大状态空间 | 小状态空间 |
实际运行时会发现,虽然两种算法的收敛路径不同,但最终得到的策略在简单网格世界中通常是相同的。这验证了两种方法在解决MDP问题上的等效性。
5. 工程实践中的调优技巧
在实现过程中,有几个关键点会显著影响算法表现:
状态转移矩阵的处理
# 预计算状态转移矩阵可加速迭代 def build_transition_matrix(env): trans_mat = np.zeros((env.size, env.size, len(env.actions), env.size, env.size)) for i in range(env.size): for j in range(env.size): for a_idx, action in enumerate(env.actions): new_state, reward = env.step((i,j), action) trans_mat[i,j,a_idx,new_state[0],new_state[1]] = 1 return trans_mat收敛性加速技巧
- 值迭代中可以使用Gauss-Seidel更新(就地更新价值函数)
- 策略迭代中可以采用异步更新策略
- 设置合理的收敛阈值θ平衡精度与速度
调试建议
- 打印每次迭代的价值函数变化范数
- 可视化中间策略观察改进过程
- 对边界状态进行特殊检查
- 使用小网格世界验证算法正确性
在更复杂的环境中,可以考虑以下扩展:
- 实现截断策略迭代作为折中方案
- 添加随机动作选择模拟非确定性环境
- 用稀疏矩阵存储大型状态空间
- 并行化状态更新计算
