基于DQN的强化学习实战:从gymnasium环境搭建到pytorch模型优化
1. 从零搭建强化学习环境
第一次接触强化学习的朋友可能会被各种术语吓到,但实际操作起来你会发现,搭建一个基础的训练环境比想象中简单得多。我这里以经典的倒立摆控制任务为例,手把手带你走通整个流程。
首先需要安装两个核心工具包:
pip install gymnasium pytorchgymnasium是OpenAI gym的升级版,提供了标准化的强化学习环境接口。安装完成后,我们可以用一行代码创建倒立摆环境:
import gymnasium as gym env = gym.make("CartPole-v1")这个环境模拟的是小车上的倒立摆控制问题:
- 状态空间:4维向量(小车位置、速度、杆角度、角速度)
- 动作空间:2个离散动作(向左/向右推小车)
- 奖励机制:每存活一个时间步+1分,杆子倾斜过大或小车出界则终止
我建议先用随机策略测试下环境:
state, _ = env.reset() for _ in range(100): action = env.action_space.sample() # 随机选择动作 state, reward, terminated, truncated, _ = env.step(action) if terminated or truncated: break env.close()运行这段代码时,你会发现杆子很快就倒下了。这就是我们要用DQN解决的问题——让智能体学会平衡杆子的策略。在实际项目中,我习惯先可视化观察环境特征,这对后续设计神经网络结构很有帮助。
2. DQN算法核心实现
深度Q网络(DQN)的核心思想是用神经网络替代传统Q-learning中的Q表。我们先来看网络结构设计,这里采用三层全连接网络:
class DQN(nn.Module): def __init__(self, n_observations, n_actions): super().__init__() self.layer1 = nn.Linear(n_observations, 128) self.layer2 = nn.Linear(128, 128) self.layer3 = nn.Linear(128, n_actions) def forward(self, x): x = F.relu(self.layer1(x)) x = F.relu(self.layer2(x)) return self.layer3(x)这个结构有几个设计要点:
- 输入层维度必须与环境的状态空间一致(CartPole是4)
- 输出层维度等于动作空间大小(CartPole是2)
- 隐藏层使用ReLU激活函数避免梯度消失
- 没有在最后一层加softmax,因为我们需要的是Q值而非概率
实际训练中我发现,经验回放机制是DQN成功的关键。它解决了两个核心问题:
- 消除样本间的时序相关性
- 提高数据利用率
实现起来也很简单:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward')) class ReplayMemory: def __init__(self, capacity): self.memory = deque([], maxlen=capacity) def push(self, *args): self.memory.append(Transition(*args)) def sample(self, batch_size): return random.sample(self.memory, batch_size)建议将回放池大小设为5000-10000,batch size设为128。太小的池会导致训练不稳定,太大的池又会拖慢学习速度。
3. 训练过程的关键技巧
训练循环是强化学习最核心的部分,这里我分享几个实测有效的调参经验:
探索-利用平衡:使用ε-greedy策略时,衰减率设置很关键:
EPS_START = 0.9 # 初始探索率 EPS_END = 0.05 # 最小探索率 EPS_DECAY = 1000 # 衰减速度我建议先用较高的初始探索率(0.8-0.9),然后缓慢衰减。太快的衰减会导致智能体过早陷入局部最优。
双网络机制:DQN使用两个网络来稳定训练:
policy_net = DQN(n_observations, n_actions).to(device) target_net = DQN(n_observations, n_actions).to(device) target_net.load_state_dict(policy_net.state_dict())更新目标网络时采用软更新策略:
TAU = 0.005 # 更新系数 for key in policy_net_state_dict: target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)损失函数选择:Smooth L1 Loss比MSE更适合Q值更新:
criterion = nn.SmoothL1Loss() loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))在我的实验中,加入梯度裁剪能显著提升稳定性:
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)4. 模型优化与性能调优
当基础版本跑通后,我们可以从以下几个方向进一步提升性能:
网络结构优化:
- 尝试增加隐藏层宽度(256/512神经元)
- 加入batch normalization层
- 使用更复杂的结构如Dueling DQN
超参数调优:
BATCH_SIZE = 256 # 增大batch size GAMMA = 0.999 # 增大折扣因子 LR = 3e-4 # 调整学习率训练策略改进:
- 实现优先经验回放(Prioritized Experience Replay)
- 添加n-step bootstrap
- 使用NoisyNet替代ε-greedy
我常用的性能监控方法是绘制episode duration曲线:
def plot_durations(): plt.figure(1) durations_t = torch.tensor(episode_durations, dtype=torch.float) if len(durations_t) >= 100: means = durations_t.unfold(0, 100, 1).mean(1).view(-1) means = torch.cat((torch.zeros(99), means)) plt.plot(means.numpy())当100episode平均得分达到环境最大值(CartPole是500)时,说明模型已经收敛。如果发现曲线波动很大,可以尝试减小学习率或增大batch size。
5. 常见问题排查指南
在实践过程中,我遇到过不少坑,这里总结几个典型问题:
训练完全不收敛:
- 检查reward设计是否合理
- 确认状态归一化是否正确
- 尝试调大探索率ε
模型表现忽好忽坏:
- 可能是回放池太小导致过拟合
- 尝试降低学习率
- 检查目标网络更新频率
GPU内存不足:
- 减小batch size
- 使用梯度累积
- 精简网络结构
一个实用的debug技巧是可视化智能体的决策过程:
def visualize_policy(): state, _ = env.reset() for t in range(500): with torch.no_grad(): action = policy_net(state).max(1).indices.view(1,1) state, _, terminated, truncated, _ = env.step(action.item()) env.render() if terminated or truncated: break如果发现智能体做出明显不合理的行为,可能是网络结构或reward设计有问题。我在第一次实现时就遇到过小车一直朝一个方向移动的情况,后来发现是reward函数没有考虑位置偏移的惩罚。
6. 扩展应用与进阶方向
掌握基础DQN后,可以尝试更复杂的应用场景:
其他经典环境:
- MountainCar(连续状态离散动作)
- LunarLander(多维度控制)
- Atari游戏(图像输入)
算法变种:
- Double DQN:解决Q值过估计
- Dueling DQN:分离状态价值和优势函数
- Rainbow:整合多种改进技巧
实际工程优化:
- 使用torch.compile加速模型
- 实现分布式经验回放
- 添加模型检查点和恢复功能
我在一个工业控制项目中应用DQN时,发现将PyTorch模型转换为ONNX格式后,推理速度能提升2-3倍。这对于实时性要求高的场景非常有用:
dummy_input = torch.randn(1, n_observations) torch.onnx.export(policy_net, dummy_input, "model.onnx")强化学习的魅力在于,同一个算法框架可以通过调整适应各种不同的问题。建议从CartPole这样的简单环境开始,逐步挑战更复杂的任务,这样能更好地理解算法本质。
