用《Flappy Bird》游戏带你搞懂强化学习:从Q-learning到DQN的保姆级实战
用《Flappy Bird》游戏带你搞懂强化学习:从Q-learning到DQN的保姆级实战
还记得2014年那个让人又爱又恨的《Flappy Bird》吗?这只像素小鸟曾让无数玩家抓狂,现在我们将用这个经典游戏作为实验室,带你亲手打造一个会自己玩游戏的AI。这不是普通的编程教程,而是一场从零开始的强化学习探险——不需要高深的数学基础,只要你会写Python,就能在3小时内见证AI从"菜鸟"到"高手"的进化历程。
1. 环境搭建与游戏机制解析
在PyGame中重建Flappy Bird只需要不到100行代码,但这个简单游戏蕴含着强化学习的绝佳教学场景。让我们先拆解游戏的核心机制:
import pygame import random # 初始化游戏 pygame.init() screen = pygame.display.set_mode((400, 600)) clock = pygame.time.Clock() # 小鸟物理参数 bird_y = 300 bird_velocity = 0 gravity = 0.25 flap_strength = -5游戏状态可以简化为三个关键参数:
- 垂直距离:小鸟与下一个管道开口中心的垂直差值
- 水平距离:小鸟与下一个管道开口的水平距离
- 当前速度:小鸟的瞬时垂直速度
提示:在强化学习中,状态设计直接影响训练效果。过于简单的状态表示可能导致AI无法学习复杂策略。
我们设计的奖励函数如下表所示:
| 事件 | 即时奖励 | 说明 |
|---|---|---|
| 存活一帧 | +0.1 | 鼓励延长生存时间 |
| 通过管道 | +1 | 主要目标奖励 |
| 撞击障碍 | -1000 | 强烈惩罚终止行为 |
| 超出边界 | -1000 | 防止逃避策略 |
2. Q-learning实战:从零构建决策表格
Q-learning的核心是构建一个"决策手册"——Q表格,它记录了在特定状态下采取某个动作的长期价值。对于我们的Flappy Bird:
import numpy as np # 离散化状态空间 vertical_bins = np.linspace(-200, 200, 20) # 垂直距离分20档 horizontal_bins = np.linspace(0, 400, 20) # 水平距离分20档 velocity_bins = np.linspace(-8, 8, 10) # 速度分10档 # 初始化Q表格 (状态1 × 状态2 × 状态3 × 动作) q_table = np.zeros((20, 20, 10, 2))训练过程中的关键参数配置:
# 超参数设置 LEARNING_RATE = 0.1 DISCOUNT_FACTOR = 0.95 EPISODES = 10000 epsilon = 1.0 # 初始探索率 EPSILON_DECAY = 0.9995训练循环的核心逻辑:
for episode in range(EPISODES): state = env.reset() done = False while not done: # ε-greedy策略 if random.random() < epsilon: action = random.randint(0, 1) # 随机探索 else: action = np.argmax(q_table[state]) # 利用已知知识 next_state, reward, done = env.step(action) # Q值更新公式 current_q = q_table[state + (action,)] max_next_q = np.max(q_table[next_state]) new_q = current_q + LEARNING_RATE * (reward + DISCOUNT_FACTOR * max_next_q - current_q) q_table[state + (action,)] = new_q state = next_state epsilon *= EPSILON_DECAY # 衰减探索率注意:Q-learning面临维度灾难——当状态变量增加或精度要求提高时,Q表格会指数级膨胀。这就是我们需要深度强化学习的原因。
3. DQN进阶:用神经网络替代Q表格
Deep Q-Network (DQN) 用神经网络参数化Q函数,解决了状态空间爆炸问题。我们使用PyTorch构建一个简单的CNN:
import torch import torch.nn as nn class DQN(nn.Module): def __init__(self, input_shape, n_actions): super(DQN, self).__init__() self.conv = nn.Sequential( nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(), nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU() ) conv_out_size = self._get_conv_out(input_shape) self.fc = nn.Sequential( nn.Linear(conv_out_size, 512), nn.ReLU(), nn.Linear(512, n_actions) ) def _get_conv_out(self, shape): o = self.conv(torch.zeros(1, *shape)) return int(np.prod(o.size())) def forward(self, x): conv_out = self.conv(x).view(x.size()[0], -1) return self.fc(conv_out)DQN引入了两个关键技术改进:
经验回放(Experience Replay):
class ReplayBuffer: def __init__(self, capacity): self.buffer = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.buffer, batch_size)目标网络(Target Network):
target_net = DQN(input_shape, n_actions).to(device) target_net.load_state_dict(policy_net.state_dict()) target_update_counter = 0
4. 训练技巧与性能优化
在实际训练中,我们发现几个关键技巧能显著提升表现:
学习率调度:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.9)帧堆叠(Frame Stacking):
class FrameStack: def __init__(self, env, k): self.env = env self.k = k # 堆叠帧数 self.frames = deque([], maxlen=k) def reset(self): obs = self.env.reset() for _ in range(self.k): self.frames.append(obs) return self._get_obs() def step(self, action): obs, reward, done, info = self.env.step(action) self.frames.append(obs) return self._get_obs(), reward, done, info def _get_obs(self): return np.concatenate(list(self.frames), axis=0)Double DQN (DDQN)改进:
# 普通DQN的Q值计算 q_values = policy_net(states).gather(1, actions) # DDQN的Q值计算 with torch.no_grad(): next_actions = policy_net(next_states).max(1)[1].unsqueeze(1) next_q_values = target_net(next_states).gather(1, next_actions)训练过程中的典型性能指标变化:
| 训练阶段 | 平均得分 | 最大得分 | 存活时间 |
|---|---|---|---|
| 初期(0-1k) | 1.2 | 5 | 30帧 |
| 中期(1k-5k) | 15.7 | 42 | 500帧 |
| 后期(5k+) | 68.3 | ∞ | 2000+帧 |
5. 可视化分析与调试技巧
理解AI决策过程的关键是可视化。我们开发了几个调试工具:
Q值热力图:
def plot_q_values(state): q_values = model(torch.FloatTensor(state).unsqueeze(0)) plt.imshow(q_values.detach().numpy(), cmap='hot', interpolation='nearest') plt.colorbar() plt.show()策略轨迹回放:
def replay_episode(model, env): state = env.reset() frames = [] done = False while not done: frames.append(env.render(mode='rgb_array')) action = model.act(state) state, _, done, _ = env.step(action) return frames常见问题排查指南:
AI完全不学习:
- 检查奖励函数设计
- 验证梯度是否在更新
- 调整学习率和折扣因子
表现波动大:
- 增大经验回放缓冲区
- 降低探索率衰减速度
- 尝试DDQN结构
过拟合当前环境:
- 引入随机初始条件
- 使用课程学习策略
- 添加正则化项
在Colab笔记本上运行完整代码后,你会看到AI从最初的随机乱飞到最终能无限生存的完整进化过程。有趣的是,AI往往会发展出与人类不同的策略——比如紧贴管道上沿飞行以减少垂直移动幅度。
