当前位置: 首页 > news >正文

从DDPG到MADDPG:给单智能体算法加上‘队友视野’需要改哪几行代码?

从DDPG到MADDPG:核心代码改造实战指南

当我们需要将单智能体强化学习算法扩展到多智能体场景时,MADDPG(Multi-Agent DDPG)提供了一种优雅的解决方案。本文将以代码对比的方式,逐步展示如何将一个标准的DDPG实现改造成支持多智能体协作的MADDPG版本。我们将聚焦于三个关键改造点:Critic网络的输入扩展、经验回放缓冲区的调整以及训练流程的协同优化。

1. 网络架构的改造:让Critic拥有全局视野

DDPG的Critic网络只需要评估单个智能体的状态-动作对,而MADDPG的核心创新在于让Critic在训练时能够访问所有智能体的信息。这种"集中式训练"的设计需要我们对网络结构进行针对性调整。

1.1 Critic网络的输入维度扩展

在DDPG中,Critic的输入通常是(state, action)对。我们需要将其扩展为接收所有智能体的联合状态和动作:

# DDPG Critic网络输入 class Critic(nn.Module): def __init__(self, state_dim, action_dim): super().__init__() self.fc1 = nn.Linear(state_dim + action_dim, 256) # MADDPG Critic网络输入 class MADDPGCritic(nn.Module): def __init__(self, state_dim, action_dim, n_agents): super().__init__() # 输入维度变为全局状态+所有智能体动作的拼接 self.fc1 = nn.Linear(state_dim + n_agents * action_dim, 256)

关键改动点:

  • 输入维度从state_dim + action_dim变为state_dim + n_agents * action_dim
  • 前向传播时需要拼接所有智能体的动作信息

1.2 Actor网络保持独立性

与Critic不同,Actor网络在执行阶段仍然只依赖局部观察,因此其结构无需改变:

# Actor网络(DDPG和MADDPG保持一致) class Actor(nn.Module): def __init__(self, obs_dim, action_dim): super().__init__() self.fc1 = nn.Linear(obs_dim, 256) self.fc2 = nn.Linear(256, action_dim) def forward(self, obs): x = F.relu(self.fc1(obs)) return torch.tanh(self.fc2(x)) # 假设动作空间在[-1,1]范围内

2. 经验回放缓冲区的改造

多智能体环境中的经验存储需要考虑各智能体观察的同步性,我们需要设计能够保存全局状态和个体观察的缓冲区。

2.1 多智能体经验存储结构

class MultiAgentReplayBuffer: def __init__(self, capacity, obs_dims, state_dim, action_dims): self.capacity = capacity self.n_agents = len(obs_dims) # 为每个智能体创建独立的观察存储 self.obs_buffers = [ np.zeros((capacity, dim)) for dim in obs_dims ] self.next_obs_buffers = [ np.zeros((capacity, dim)) for dim in obs_dims ] # 全局状态存储 self.state_buffer = np.zeros((capacity, state_dim)) self.next_state_buffer = np.zeros((capacity, state_dim)) # 动作和奖励存储 self.action_buffers = [ np.zeros((capacity, dim)) for dim in action_dims ] self.reward_buffers = [ np.zeros((capacity, 1)) for _ in range(self.n_agents) ] self.done_buffer = np.zeros((capacity, 1), dtype=np.float32) self.pos = 0 self.size = 0

2.2 经验存储接口的变化

def add(self, obs_list, actions, rewards, next_obs_list, state, next_state, done): # 存储每个智能体的独立观察 for i in range(self.n_agents): self.obs_buffers[i][self.pos] = obs_list[i] self.next_obs_buffers[i][self.pos] = next_obs_list[i] self.action_buffers[i][self.pos] = actions[i] self.reward_buffers[i][self.pos] = rewards[i] # 存储全局状态 self.state_buffer[self.pos] = state self.next_state_buffer[self.pos] = next_state self.done_buffer[self.pos] = done self.pos = (self.pos + 1) % self.capacity self.size = min(self.size + 1, self.capacity)

3. 训练流程的协同优化

MADDPG的训练需要协调多个智能体的参数更新,这要求我们对训练循环进行重构。

3.1 集中式Critic更新

def update_critics(self, agents, batch_size): # 采样批量经验 idx = np.random.randint(0, self.size, size=batch_size) # 准备所有智能体的数据 states = torch.FloatTensor(self.state_buffer[idx]).to(device) next_states = torch.FloatTensor(self.next_state_buffer[idx]).to(device) # 收集所有智能体的当前和下一个动作 all_actions = [] all_next_actions = [] for i, agent in enumerate(agents): obs = torch.FloatTensor(self.obs_buffers[i][idx]).to(device) next_obs = torch.FloatTensor(self.next_obs_buffers[i][idx]).to(device) # 当前策略动作 current_actions = agent.actor(obs) # 目标策略动作 next_actions = agent.target_actor(next_obs) all_actions.append(current_actions) all_next_actions.append(next_actions) # 拼接所有动作 joint_actions = torch.cat(all_actions, dim=1) joint_next_actions = torch.cat(all_next_actions, dim=1) # 为每个智能体更新Critic for i, agent in enumerate(agents): rewards = torch.FloatTensor(self.reward_buffers[i][idx]).to(device) dones = torch.FloatTensor(self.done_buffer[idx]).to(device) # 计算目标Q值 with torch.no_grad(): target_q = agent.target_critic(next_states, joint_next_actions) y = rewards + (1 - dones) * self.gamma * target_q # 计算当前Q值 current_q = agent.critic(states, joint_actions) # 更新Critic critic_loss = F.mse_loss(current_q, y) agent.critic_optimizer.zero_grad() critic_loss.backward() agent.critic_optimizer.step()

3.2 分布式Actor更新

def update_actors(self, agents, batch_size): idx = np.random.randint(0, self.size, size=batch_size) states = torch.FloatTensor(self.state_buffer[idx]).to(device) # 为每个智能体更新Actor for i, agent in enumerate(agents): obs = torch.FloatTensor(self.obs_buffers[i][idx]).to(device) # 获取当前智能体的动作 current_actions = agent.actor(obs) # 获取其他智能体的动作(固定参数) other_actions = [] for j, other_agent in enumerate(agents): if j != i: other_obs = torch.FloatTensor(self.obs_buffers[j][idx]).to(device) other_action = other_agent.actor(other_obs).detach() other_actions.append(other_action) # 拼接所有动作(当前智能体+其他智能体) if other_actions: all_actions = torch.cat([current_actions] + other_actions, dim=1) else: all_actions = current_actions # 计算策略梯度 actor_loss = -agent.critic(states, all_actions).mean() agent.actor_optimizer.zero_grad() actor_loss.backward() agent.actor_optimizer.step()

4. 实战中的关键调整与优化

在实际应用中,我们发现以下几个调整对MADDPG的性能有显著影响:

4.1 探索噪声的协调

在多智能体环境中,探索噪声的设置需要更加谨慎:

def get_action(self, obs, noise_scale=0.1): obs = torch.FloatTensor(obs).unsqueeze(0).to(device) action = self.actor(obs).squeeze(0).cpu().detach().numpy() # 使用衰减的噪声 noise = noise_scale * np.random.randn(*action.shape) return np.clip(action + noise, -1, 1) # 假设动作空间在[-1,1]范围内

4.2 训练稳定性的提升技巧

技巧DDPG实现MADDPG调整
目标网络更新单独更新同步更新所有智能体目标网络
经验回放统一采样确保同一批次包含同步的经验
学习率调度固定学习率可能需要更保守的学习率衰减

4.3 多智能体特有的超参数调整

# 典型MADDPG超参数配置 config = { 'actor_lr': 1e-4, # 通常比DDPG更小的学习率 'critic_lr': 1e-3, 'tau': 0.01, # 目标网络软更新参数 'gamma': 0.95, # 折扣因子 'batch_size': 1024, # 更大的批次以稳定训练 'buffer_size': int(1e6), # 更大的回放缓冲区 'noise_start': 0.3, # 初始探索噪声 'noise_decay': 0.9995 # 噪声衰减率 }

在将DDPG扩展到多智能体场景时,最大的挑战不是算法原理的理解,而是工程实现上的细节处理。特别是在处理多个智能体的经验同步、网络参数更新顺序等实际问题时,需要格外注意数据的一致性和训练的稳定性。

http://www.jsqmd.com/news/708202/

相关文章:

  • ComfyUI-Impact-Pack插件安装指南:3步搞定AI图像增强完整配置
  • 盘点2026年重庆买卧室家具公司,源点宜联购排名如何 - 工业设备
  • 聊聊Mybatis-Plus中的10个坑!
  • 牛客网金三银四最新的 java 面试题及答案
  • 2026年国内外超声波液位差计十大品牌排名最新版 - 仪表人小余
  • 避开这些坑!ESP32-WROVER模组PSRAM使用全指南(含硬件连接与版本差异)
  • Cortex-M55向量指令集:嵌入式SIMD加速与DSP优化
  • 2026年环保裂解设备公司排行榜,四海能源性价比非常高 - 工业设备
  • 2026年江浙沪皖回转支承实力供应商排名,前十有哪些 - 工业设备
  • Diablo Edit2:暗黑破坏神II角色编辑器,5分钟打造完美角色的终极秘籍
  • 2026年西南换电加盟创业完全指南:低成本运营模式深度横评与B端选型避坑指南 - 优质企业观察收录
  • 2026年西南换电加盟创业指南:低成本高效率运营方案对标与官方直联渠道 - 优质企业观察收录
  • 闲鱼自动化数据采集系统:终极配置指南与智能监控解决方案
  • 2026年亚固官方联系方式公示,门锁五金一站式服务合作便捷入口 - 第三方测评
  • 用啤酒和牛奶讲明白:Ecoinvent里Cutoff、Consequential、APOS到底有啥不一样?
  • 2026年中国热门的吉利远程商用车公司推荐,天津地区靠谱的有哪些 - 工业设备
  • Headless Chrome实战:从Docker快速玩转到K8s生产部署,附Java连接避坑指南
  • 避坑指南:TDengine 3.0.2.6连接DBeaver最全配置流程(含JDBC驱动编译与两种驱动方式详解)
  • 2026最新墨西哥海运专线/墨西哥空派小包专线公司推荐!广东优质权威榜单发布,实力靠谱广州等地物流服务商精选 - 博客万
  • Unity新手避坑:用CharacterController搞定第一人称移动与跳跃(含地面检测详解)
  • 7天掌握数据科学核心技能:零基础实战入门指南
  • 2026年宁波定制伸缩门选购,口碑好的品牌排名 - 工业设备
  • STM32F103 SDIO读写SD卡,从硬件焊接到HAL库配置的完整避坑指南
  • PCIe 6.0都来了,你的项目还在用Gen3?聊聊编码演进史与选型指南(8B/10B到PAM-4)
  • 别再手动截图了!一个Python脚本搞定.dat数据到图片的自动转换与归档
  • 2026年全国风机采购完全指南:湖北消防排烟与工业风机厂家深度横评 - 优质企业观察收录
  • 2026年度全国废气处理设备及配套服务品牌综合测评报告 - 深度智识库
  • Weka回归项目实战:从数据探索到模型优化
  • R语言机器学习数据预处理全流程实战指南
  • SAP Fiori Excel 导出升级,SmartTable 终于把界面里的选择带进了 Excel