当前位置: 首页 > news >正文

用Python和TensorFlow训练AI玩贪吃蛇:从游戏逻辑到DQN算法实战(附完整代码)

用Python和TensorFlow训练AI玩贪吃蛇:从游戏逻辑到DQN算法实战

贪吃蛇这个经典游戏,几乎每个人都玩过。但你是否想过,让AI来玩这个游戏会是什么样子?本文将带你从零开始,用Python和TensorFlow构建一个能够自主玩贪吃蛇的AI系统。不同于简单的规则式AI,我们将使用深度强化学习中的DQN算法,让AI真正"学会"如何玩这个游戏。

1. 项目准备与环境搭建

在开始编码之前,我们需要准备好开发环境。这个项目需要以下几个主要组件:

  • Python 3.7或更高版本
  • Pygame库(用于游戏界面)
  • TensorFlow 2.x(用于构建和训练神经网络)
  • NumPy(用于数值计算)

安装这些依赖非常简单,只需在命令行中执行以下命令:

pip install pygame tensorflow numpy

对于硬件要求,虽然可以在CPU上运行,但如果有NVIDIA显卡并安装了CUDA,训练速度会显著提升。建议至少4GB内存,因为神经网络训练过程会比较消耗资源。

项目目录结构建议如下:

/snake_ai /game __init__.py snake.py # 游戏逻辑 render.py # 游戏渲染 /rl __init__.py dqn.py # DQN算法实现 memory.py # 经验回放缓冲区 config.py # 配置文件 train.py # 训练脚本 play.py # 人类游玩脚本

2. 贪吃蛇游戏逻辑实现

首先我们需要构建贪吃蛇游戏的基本框架。使用Pygame可以方便地创建游戏窗口和处理用户输入。

2.1 游戏核心类设计

我们创建三个主要类:SnakeFoodGame。下面是Snake类的核心代码:

class Snake: def __init__(self, block_size=20, width=800, height=600): self.length = 3 self.positions = [(width // 2, height // 2)] self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)]) self.block_size = block_size self.width = width self.height = height self.color = (0, 255, 0) # 绿色 def get_head_position(self): return self.positions[0] def turn(self, new_direction): # 防止180度转弯 if (new_direction[0] * -1, new_direction[1] * -1) != self.direction: self.direction = new_direction def move(self): head = self.get_head_position() x, y = self.direction new_x = (head[0] + (x * self.block_size)) % self.width new_y = (head[1] + (y * self.block_size)) % self.height new_position = (new_x, new_y) self.positions.insert(0, new_position) if len(self.positions) > self.length: self.positions.pop() def reset(self): self.length = 3 self.positions = [(self.width // 2, self.height // 2)] self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)]) def draw(self, surface): for p in self.positions: rect = pygame.Rect((p[0], p[1]), (self.block_size, self.block_size)) pygame.draw.rect(surface, self.color, rect) pygame.draw.rect(surface, (0, 0, 0), rect, 1)

2.2 游戏主循环

游戏主循环负责处理输入、更新游戏状态和渲染画面:

class Game: def __init__(self, width=800, height=600, block_size=20): pygame.init() self.screen = pygame.display.set_mode((width, height)) self.clock = pygame.time.Clock() self.snake = Snake(block_size, width, height) self.food = Food(block_size, width, height) self.width = width self.height = height self.block_size = block_size self.score = 0 def run(self): running = True while running: for event in pygame.event.get(): if event.type == pygame.QUIT: running = False elif event.type == pygame.KEYDOWN: if event.key == pygame.K_UP: self.snake.turn((0, -1)) elif event.key == pygame.K_DOWN: self.snake.turn((0, 1)) elif event.key == pygame.K_LEFT: self.snake.turn((-1, 0)) elif event.key == pygame.K_RIGHT: self.snake.turn((1, 0)) self.snake.move() # 检测是否吃到食物 if self.snake.get_head_position() == self.food.position: self.snake.length += 1 self.score += 1 self.food = Food(self.block_size, self.width, self.height) # 检测碰撞 if self.snake.get_head_position() in self.snake.positions[1:]: print(f"Game Over! Score: {self.score}") self.snake.reset() self.score = 0 # 渲染 self.screen.fill((255, 255, 255)) self.snake.draw(self.screen) self.food.draw(self.screen) pygame.display.update() self.clock.tick(10) # 控制游戏速度 pygame.quit()

3. DQN算法原理与实现

深度Q网络(DQN)是强化学习中的一种重要算法,它结合了Q-learning和深度神经网络的优点。

3.1 DQN核心概念

DQN的核心思想是使用神经网络来近似Q函数,即状态-动作值函数。Q函数表示在某个状态下采取某个动作所能获得的预期回报。

DQN有几个关键组件:

  1. 经验回放(Experience Replay):存储智能体的经验(状态,动作,奖励,新状态)在记忆库中,训练时从中随机采样,打破数据间的相关性。
  2. 目标网络(Target Network):使用一个独立的网络来计算目标Q值,提高训练稳定性。
  3. ε-贪婪策略(ε-Greedy Policy):在探索和利用之间取得平衡,开始时更多探索,逐渐增加利用。

3.2 DQN实现代码

下面是DQN的核心实现:

import numpy as np import tensorflow as tf from collections import deque import random class DQNAgent: def __init__(self, state_size, action_size): self.state_size = state_size self.action_size = action_size self.memory = deque(maxlen=2000) self.gamma = 0.95 # 折扣因子 self.epsilon = 1.0 # 探索率 self.epsilon_min = 0.01 self.epsilon_decay = 0.995 self.learning_rate = 0.001 self.model = self._build_model() self.target_model = self._build_model() self.update_target_model() def _build_model(self): model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(24, input_dim=self.state_size, activation='relu')) model.add(tf.keras.layers.Dense(24, activation='relu')) model.add(tf.keras.layers.Dense(self.action_size, activation='linear')) model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(lr=self.learning_rate)) return model def update_target_model(self): self.target_model.set_weights(self.model.get_weights()) def remember(self, state, action, reward, next_state, done): self.memory.append((state, action, reward, next_state, done)) def act(self, state): if np.random.rand() <= self.epsilon: return random.randrange(self.action_size) act_values = self.model.predict(state) return np.argmax(act_values[0]) def replay(self, batch_size): if len(self.memory) < batch_size: return minibatch = random.sample(self.memory, batch_size) states = np.array([i[0] for i in minibatch]) actions = np.array([i[1] for i in minibatch]) rewards = np.array([i[2] for i in minibatch]) next_states = np.array([i[3] for i in minibatch]) dones = np.array([i[4] for i in minibatch]) states = np.squeeze(states) next_states = np.squeeze(next_states) targets = rewards + self.gamma * (np.amax(self.target_model.predict_on_batch(next_states), axis=1)) * (1 - dones) targets_full = self.model.predict_on_batch(states) ind = np.array([i for i in range(batch_size)]) targets_full[[ind], [actions]] = targets self.model.fit(states, targets_full, epochs=1, verbose=0) if self.epsilon > self.epsilon_min: self.epsilon *= self.epsilon_decay def load(self, name): self.model.load_weights(name) def save(self, name): self.model.save_weights(name)

4. 训练AI玩贪吃蛇

现在我们将游戏环境和DQN算法结合起来,训练AI玩贪吃蛇。

4.1 状态表示

我们需要定义如何将游戏状态表示为神经网络可以理解的输入。对于贪吃蛇游戏,状态可以包括:

  • 蛇头周围四个方向是否有障碍(蛇身或墙壁)
  • 食物相对于蛇头的位置(左/右/上/下)
  • 蛇当前的移动方向
def get_state(self): head = self.snake.get_head_position() food = self.food.position # 计算四个方向的点 point_l = (head[0] - self.block_size, head[1]) point_r = (head[0] + self.block_size, head[1]) point_u = (head[0], head[1] - self.block_size) point_d = (head[0], head[1] + self.block_size) # 当前移动方向 dir_l = self.snake.direction == (-1, 0) dir_r = self.snake.direction == (1, 0) dir_u = self.snake.direction == (0, -1) dir_d = self.snake.direction == (0, 1) state = [ # 危险直行 (dir_r and self.is_collision(point_r)) or (dir_l and self.is_collision(point_l)) or (dir_u and self.is_collision(point_u)) or (dir_d and self.is_collision(point_d)), # 危险右转 (dir_u and self.is_collision(point_r)) or (dir_d and self.is_collision(point_l)) or (dir_l and self.is_collision(point_u)) or (dir_r and self.is_collision(point_d)), # 危险左转 (dir_d and self.is_collision(point_r)) or (dir_u and self.is_collision(point_l)) or (dir_r and self.is_collision(point_u)) or (dir_l and self.is_collision(point_d)), # 移动方向 dir_l, dir_r, dir_u, dir_d, # 食物位置 food[0] < head[0], # 食物在左 food[0] > head[0], # 食物在右 food[1] < head[1], # 食物在上 food[1] > head[1] # 食物在下 ] return np.array(state, dtype=int)

4.2 奖励函数设计

奖励函数是强化学习中最关键的部分之一,它告诉AI什么是好的行为,什么是坏的行为。对于贪吃蛇游戏,我们可以设计如下奖励:

  • 吃到食物:+10
  • 撞到自己或墙壁:-10
  • 靠近食物:+1
  • 远离食物:-1
  • 每移动一步:-0.1(鼓励高效)
def get_reward(self, snake, food, done): if done: return -10 if snake.get_head_position() == food.position: return 10 # 计算与食物的距离 head = snake.get_head_position() food_pos = food.position new_dist = abs(head[0] - food_pos[0]) + abs(head[1] - food_pos[1]) # 如果距离减小,给予正奖励;否则负奖励 if new_dist < self.prev_distance: reward = 1 else: reward = -1 self.prev_distance = new_dist # 每步的小惩罚 reward -= 0.1 return reward

4.3 训练过程

训练过程主要包括以下步骤:

  1. 初始化环境和智能体
  2. 获取当前状态
  3. 智能体选择动作
  4. 执行动作,获取新状态和奖励
  5. 存储经验到记忆库
  6. 训练智能体
  7. 定期更新目标网络
def train(): pygame.init() width, height, block_size = 800, 600, 20 game = Game(width, height, block_size) agent = DQNAgent(state_size=11, action_size=3) # 3动作:直行、右转、左转 episodes = 1000 batch_size = 32 for e in range(episodes): game.reset() state = game.get_state() state = np.reshape(state, [1, 11]) total_reward = 0 while True: action = agent.act(state) # 执行动作 if action == 0: # 直行 pass elif action == 1: # 右转 if game.snake.direction == (0, -1): game.snake.turn((1, 0)) elif game.snake.direction == (1, 0): game.snake.turn((0, 1)) elif game.snake.direction == (0, 1): game.snake.turn((-1, 0)) elif game.snake.direction == (-1, 0): game.snake.turn((0, -1)) elif action == 2: # 左转 if game.snake.direction == (0, -1): game.snake.turn((-1, 0)) elif game.snake.direction == (-1, 0): game.snake.turn((0, 1)) elif game.snake.direction == (0, 1): game.snake.turn((1, 0)) elif game.snake.direction == (1, 0): game.snake.turn((0, -1)) game.snake.move() # 检查游戏状态 done = False if game.snake.get_head_position() in game.snake.positions[1:]: done = True # 检查是否吃到食物 if game.snake.get_head_position() == game.food.position: game.snake.length += 1 game.food = Food(block_size, width, height) # 获取奖励和新状态 reward = game.get_reward(game.snake, game.food, done) total_reward += reward next_state = game.get_state() next_state = np.reshape(next_state, [1, 11]) # 存储经验 agent.remember(state, action, reward, next_state, done) state = next_state if done: print(f"Episode: {e}/{episodes}, Score: {game.snake.length}, Total reward: {total_reward}, Epsilon: {agent.epsilon:.2f}") break if len(agent.memory) > batch_size: agent.replay(batch_size) # 定期更新目标网络 if e % 10 == 0: agent.update_target_model() # 定期保存模型 if e % 100 == 0: agent.save(f"snake_dqn_{e}.h5") agent.save("snake_dqn_final.h5")

5. 调优与改进

训练过程中,你可能会遇到AI表现不佳的情况。以下是几个常见的调优方向:

5.1 奖励函数调整

奖励函数的设计对训练效果影响巨大。可以尝试以下调整:

  • 增加对长时间存活的奖励
  • 调整靠近/远离食物的奖励幅度
  • 增加对形成循环移动的惩罚

5.2 网络结构优化

可以尝试更复杂的网络结构:

def _build_model(self): model = tf.keras.Sequential([ tf.keras.layers.Dense(64, input_dim=self.state_size, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(self.action_size, activation='linear') ]) model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(lr=self.learning_rate)) return model

5.3 训练参数调整

关键训练参数包括:

参数建议值说明
γ (gamma)0.9-0.99折扣因子,越大表示越重视长期奖励
ε (epsilon)1.0→0.01探索率,初始高探索,逐渐降低
ε衰减0.995控制探索率降低速度
学习率0.0001-0.001影响权重更新幅度
批次大小32-64每次训练的样本数量
记忆容量1000-10000经验回放缓冲区大小

5.4 高级技巧

  1. 双DQN(Double DQN):使用两个网络分别选择动作和评估动作,减少过高估计问题。
  2. 优先级经验回放(Prioritized Experience Replay):给重要的经验样本更高采样概率。
  3. 决斗网络架构(Dueling Network):将Q值分解为状态值和优势函数。

实现双DQN只需修改replay方法:

def replay(self, batch_size): if len(self.memory) < batch_size: return minibatch = random.sample(self.memory, batch_size) states = np.array([i[0] for i in minibatch]) actions = np.array([i[1] for i in minibatch]) rewards = np.array([i[2] for i in minibatch]) next_states = np.array([i[3] for i in minibatch]) dones = np.array([i[4] for i in minibatch]) states = np.squeeze(states) next_states = np.squeeze(next_states) # 双DQN修改部分 next_actions = np.argmax(self.model.predict_on_batch(next_states), axis=1) q_values_next = self.target_model.predict_on_batch(next_states) targets = rewards + self.gamma * q_values_next[np.arange(batch_size), next_actions] * (1 - dones) targets_full = self.model.predict_on_batch(states) ind = np.array([i for i in range(batch_size)]) targets_full[[ind], [actions]] = targets self.model.fit(states, targets_full, epochs=1, verbose=0) if self.epsilon > self.epsilon_min: self.epsilon *= self.epsilon_decay
http://www.jsqmd.com/news/989496/

相关文章:

  • 城市更新地标翻译:跨文化语境下的语言重塑与身份传达
  • 2026年新乡自动送料机厂家推荐榜单:化工厂/医药厂/新能源材料及锂电池行业精准投料设备优选 - 品牌发掘
  • Make Sense:浏览器端零安装的图像标注神器终极指南
  • 汽车电子测试耐高低温弹簧顶针优质供应商推荐:高精密pogopin/高频率pogopin连接器/优选指南 - 优质品牌商家
  • 一键下载全网视频:VideoDownloadHelper终极使用指南
  • STM32F103C8T6最小系统板直连OLED屏的Keil可运行工程(含SSD1306/SH1106驱动源码)
  • 3.1.5 平衡二叉树
  • 技术深度解析:Lapce远程SSH连接性能瓶颈与优化方案
  • GetQzonehistory:5分钟实现QQ空间历史数据完整备份的终极解决方案
  • 深度解析SageAttention量化注意力:3-5倍性能提升实战指南
  • 5分钟用AI看懂足球:体育视频智能分析实战指南
  • 密集检索中的查询感知维度选择优化方法
  • Moneta Markets亿汇:用清单方式看外汇行情信息呈现,更容易形成稳定判断
  • 洛雪音乐音源配置终极指南:三步打造你的个人无损音乐库
  • 2026年6月头部稻壳餐具模具源头厂家推荐,包装桶类模具/湿巾盖模具/刀叉勺类模具,稻壳餐具模具直销厂家推荐 - 品牌推荐师
  • 后端的异常和保护机制
  • 2026年 新疆酒店铝单板源头厂家推荐榜单:专业定制与匠心工艺品质之选 - 品牌发掘
  • Spring Boot项目里用Netty手搓一个MQTT客户端,从连接、订阅到消息重发全流程解析
  • 用Python+NetworkX模拟社交网络中的‘跟风’行为:一个演化博弈的实战案例
  • 手把手教你用Python复现STARFM时空融合算法:从Github代码到实战避坑
  • Revit2GLTF终极指南:专业级BIM模型到Web3D的高效转换解决方案
  • 让文献管理变得可视化:Zotero Style的5大创新功能
  • C语言项目实战:用uthash库给你的自定义数据结构建个高速‘查询缓存’
  • 边缘弱网环境下的离散节点高可用组网实践与全网通工业路由器选型指南
  • 遥感图像大坝检测数据集VOC+YOLO格式8350张1类别
  • AdaCNP:极端天气下电力负荷预测的概率建模方法
  • 13ft Ladder终极指南:3分钟搭建个人付费墙绕过工具
  • AI 辅助的 K8s 资源配额推荐:从经验估算到数据驱动
  • 期货量化程序 time.sleep 卡死:天勤单线程与 deadline 替代
  • 2026齐齐哈尔市老酒回收选购技术推荐 实用避坑解析 - 优质品牌商家