当前位置: 首页 > news >正文

[学习笔记]强化学习之actor-critic

继续,策略梯度之后就是actor-critic

策略梯度很好用,是吧?

那么就面临了一个问题

直接拟合策略,不像通过价值函数那样,可以通过简单方式来判断其优劣

直接表现在训练上就是,它比dqn要难以收敛。

虽然在倒立摆这个简单的环境上没有什么体现

但是直观理解一下:

虽然reinforce中的网络是直接拟合所谓策略,但是我们并不知道网络里到底是如何拟合策略的

(相对于DQN,虽然我们也不知道它怎么拟合Q-Table的,但是可以想象到它是一个多元函数,像多重感知机那样拟合出来一个复杂的函数,最终对于每个输入量都可以输出一个数值)

所以,因为不清楚如何表达策略,我们也不知道,loss到底是如何评价一个策略的优劣的

这样就很麻烦啊,不能稳定地评价一个策略的好坏,怎么学习和优化呢

这就是actor-critic所解决的问题

其实,actor-critic的思路也很简单,如果网络无法自己判断策略的好坏,那就再加一个判断好坏的工具不就行了

那么,我们之前是如何判断一个网络的好坏的呢?

诶!价值函数

那么,就在策略网络旁边加一个判断策略的价值函数网络不就成了

问题就解决了,就这么简单粗暴

actor-critic算法也不是什么困难的算法,还用不到数学推导,用直观理解就可以解决了

贴一下代码,这一切都是在为下一个boss做铺垫:trpo

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import time# Hyper Parameters
BATCH_SIZE = 32
LR = 0.01                   # learning rate
GAMMA = 0.9                 # reward discount
MEMORY_CAPACITY = 2000
env = gym.make('CartPole-v1', render_mode='human')
env = env.unwrapped
N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape[0]# Actor网络 - 基于原始网络结构,输出动作概率class ActorNet(nn.Module):def __init__(self):super(ActorNet, self).__init__()self.fc1 = nn.Linear(N_STATES, 50)self.fc1.weight.data.normal_(0, 0.1)self.out = nn.Linear(50, N_ACTIONS)self.out.weight.data.normal_(0, 0.1)def forward(self, x):x = self.fc1(x)x = F.relu(x)action_scores = self.out(x)return F.softmax(action_scores, dim=-1)  # 输出动作概率# Critic网络 - 基于原始网络结构,输出状态价值class CriticNet(nn.Module):def __init__(self):super(CriticNet, self).__init__()self.fc1 = nn.Linear(N_STATES, 50)self.fc1.weight.data.normal_(0, 0.1)self.out = nn.Linear(50, 1)  # 输出单个状态价值self.out.weight.data.normal_(0, 0.1)def forward(self, x):x = self.fc1(x)x = F.relu(x)state_value = self.out(x)return state_valueclass ActorCritic(object):def __init__(self):self.actor_net = ActorNet()self.critic_net = CriticNet()self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=LR)self.critic_optimizer = torch.optim.Adam(self.critic_net.parameters(), lr=LR)# 经验回放缓冲区self.memory_counter = 0self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 3))  # s, a, r, s_, donedef choose_action(self, x):x = torch.unsqueeze(torch.FloatTensor(x), 0)action_probs = self.actor_net(x)# 根据概率分布采样动作action_dist = torch.distributions.Categorical(action_probs)action = action_dist.sample()log_prob = action_dist.log_prob(action)return action.item(), log_probdef store_transition(self, s, a, log_prob, r, s_, done):if isinstance(s, tuple):s = s[0]if isinstance(s_, tuple):s_ = s_[0]transition = np.hstack((s, [a, log_prob.item(), r], s_, [done]))index = self.memory_counter % MEMORY_CAPACITYself.memory[index, :] = transitionself.memory_counter += 1def learn(self):if self.memory_counter < BATCH_SIZE:return# 随机采样经验sample_index = np.random.choice(min(self.memory_counter, MEMORY_CAPACITY), BATCH_SIZE)b_memory = self.memory[sample_index, :]b_s = torch.FloatTensor(b_memory[:, :N_STATES])b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))b_log_prob = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])b_r = torch.FloatTensor(b_memory[:, N_STATES+2:N_STATES+3])b_s_ = torch.FloatTensor(b_memory[:, N_STATES+3:N_STATES+3+N_STATES])b_done = torch.BoolTensor(b_memory[:, -1].astype(bool))# Critic学习:计算TD误差current_values = self.critic_net(b_s).squeeze()next_values = self.critic_net(b_s_).squeeze().detach()target_values = b_r.squeeze() + GAMMA * next_values * (~b_done).float()# 状态价值函数值# Critic损失:价值函数拟合critic_loss = F.mse_loss(current_values, target_values)# Actor学习:使用优势函数# TD误差作为优势函数,利用critic评价进行梯度下降advantage = (target_values - current_values).detach()actor_loss = -(b_log_prob.squeeze() * advantage).mean()# 更新网络
        self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 创建Actor-Critic智能体
ac_agent = ActorCritic()print('\nTraining with Actor-Critic...')
time.sleep(2)for i_episode in range(400):s, info = env.reset()ep_r = 0step_count = 0while True:env.render()# 选择动作a, log_prob = ac_agent.choose_action(s)# 执行动作s_, r, terminated, truncated, info = env.step(a)done = terminated or truncated# 修改奖励(保持与原始代码一致)x, x_dot, theta, theta_dot = s_r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8r2 = (env.theta_threshold_radians - abs(theta)) / \env.theta_threshold_radians - 0.5r = r1 + r2# 存储经验
        ac_agent.store_transition(s, a, log_prob, r, s_, done)# 学习
        ac_agent.learn()ep_r += rstep_count += 1s = s_if done:print(f'Ep: {i_episode:3d} | 'f'Ep_r: {round(ep_r, 2):6.2f} | 'f'Steps: {step_count:3d}')breakenv.close()
View Code

 

http://www.jsqmd.com/news/439839/

相关文章:

  • 2026年加拿大移民中介发布:以和中移民为代表的标杆机构深度解析 - 品牌推荐
  • 2026年音乐喷泉与景区灯光秀首选:六通喷泉公司,以全链路实力定义水舞声光新标杆 - 深度智识库
  • 2026年淑森林品牌深度解析:技术壁垒与市场前景的客观分析 - 品牌推荐
  • 2026年音叉密度计生产厂家排行榜:国内品牌强势崛起,谁将领跑? - 品牌推荐大师1
  • 设计“可吃可降解”的包装膜分子,传统塑料难降解,颠覆淀粉基高分子优化,输出安全可食,防水的包装材料。
  • 飞跃大苹果:北京圣擎航空助您轻松购纽约机票全攻略 - 今日又土又金
  • 2026年淑森林品牌深度解析:技术壁垒与市场前景的客观分析。 - 品牌推荐
  • PicoServer 跨平台 Web 实战系列(二) 路由机制与 API 设计
  • 降AI后文章变得口语化怎么办?问题出在这2点 - 我要发一区
  • 主流VS小众:公众号平台排版软件哪个好用?深度对比5款编辑器 - 鹅鹅鹅ee
  • 知网、维普、万方AIGC检测有什么区别?选错平台白花钱 - 我要发一区
  • Chats .. 发布:全面支持最新的 gpt- 模型等
  • ArrayDeque双端队列--底层原理可视化
  • 2026年杭州美发美容化妆职校大盘点,这些学校值得关注!电竞技校/美发美容化妆中专/美容化妆专业中职,职校产品有哪些 - 品牌推荐师
  • 2026年四川阴离子交换树脂/阳离子交换树脂稳定耐用实力厂家 深耕行业多年 - 深度智识库
  • 3分钟搞懂深度学习AI:梯度下降:迷雾中的下山路
  • 嵌入式中通讯帧为什么总是用0xAA、0x55做帧头?
  • 为什么AI写的文章总能被检测出来?3个原因你可能不知道 - 我要发一区
  • Note - 动态 DP
  • 2026年商用充电桩、电动车充电桩推荐与评价,解决网络覆盖与稳定性痛点 - 深度智识库
  • 新春零食大礼包推荐:旺旺大礼包性价比高、种类多、适合送同事送小朋友的年货礼包 - Top品牌推荐官
  • 2026高热度国际高中对比怎么选?附国际高中与国际学校升学率清单 - 品牌2026
  • 介绍了LiveBindings格式化的几种进阶方法: * 使用表达式列格式化。 * 自定义绑定方法。 * 使用自定义表单方法格式化。 ...
  • Mysql的索引数量是否越多越好?为什么?
  • 工程师进阶必修:如何从项目中“挖出”高价值专利?
  • 支付宝立减金别浪费!2026最新回收攻略,两大正规平台教你省心变现 - 京回收小程序
  • 厨房计时器
  • LeetCode知识点总结 - 504
  • OOP - Abstraction
  • 用过才敢说!专科生必备的AI论文软件 —— 千笔写作工具