CPT Reinforcement Learning: Complete Implementation (PyTorch, Actor-Critic + CPT)
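Below is a production-friendly implementation for continuous-control tasks (the policy head is Gaussian, so a discrete task would need a categorical head instead). It uses Cumulative Prospect Theory (CPT) to reshape the advantage function.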
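For reference, the two CPT ingredients that the code's constructor arguments (`alpha`, `beta`, `lambda_loss`, `gamma_gain`, `gamma_loss`, `reference`) correspond to can be written in the standard Tversky-Kahneman form. This is a sketch of the textbook definitions, not a derivation specific to this implementation; treating the discounted return $G$ minus the reference point as the outcome $x$ is an assumption read off the code.

```latex
v(x) =
\begin{cases}
  x^{\alpha},              & x \ge 0 \\
  -\lambda\,(-x)^{\beta},  & x < 0
\end{cases}
\qquad x = G - \text{reference}

w^{\pm}(p) = \frac{p^{\gamma_{\pm}}}
                  {\bigl(p^{\gamma_{\pm}} + (1-p)^{\gamma_{\pm}}\bigr)^{1/\gamma_{\pm}}}

% Defaults used in the code:
% \alpha = \beta = 0.88,\ \lambda = 2.25,\ \gamma_{+} = 0.61,\ \gamma_{-} = 0.69
```

Here $w^{+}$ applies to gains and $w^{-}$ to losses; $\lambda > 1$ is what makes losses weigh more than equally sized gains.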
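Recommended configuration (used by default)
- Environment: LunarLanderContinuous-v2 (continuous action space, moderate difficulty, makes the difference in CPT risk preferences easy to see)
- Framework: PyTorch + Gymnasium
- Algorithm: Actor-Critic (A2C style) + CPT value reshaping
- Extras: transition buffer + adaptive reference point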
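To make "CPT value reshaping" and the reference point concrete, here is a minimal standard-library sketch. The helper names (`discounted_returns`, `cpt_value`, `cpt_returns`) and the three-step reward list are made up for illustration; only the parameter defaults are taken from the code. It shows that the same trajectory can be scored entirely as gains or entirely as losses depending on where the reference point sits.

```python
def discounted_returns(rewards, gamma=0.99):
    """Return G_t = r_t + gamma * G_{t+1} for each step of one ordered trajectory."""
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def cpt_value(x, alpha=0.88, beta=0.88, lambda_loss=2.25):
    """CPT value function: concave for gains, steeper (loss-averse) for losses."""
    return x ** alpha if x >= 0 else -lambda_loss * (-x) ** beta


def cpt_returns(rewards, reference=0.0, gamma=0.99):
    """Apply the CPT value function to each return measured from the reference point."""
    return [cpt_value(g - reference) for g in discounted_returns(rewards, gamma)]


rewards = [1.0, 1.0, 1.0]  # hypothetical per-step rewards
low_ref = cpt_returns(rewards, reference=0.0)   # every return counts as a gain
high_ref = cpt_returns(rewards, reference=5.0)  # every return counts as a loss
print("reference=0.0:", [round(v, 3) for v in low_ref])
print("reference=5.0:", [round(v, 3) for v in high_ref])
```

Raising the reference point as the agent improves therefore turns previously acceptable returns into losses, which is the motivation for adapting it during training.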
Full code (can be run directly)
```python
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque


# ====================== CPT core functions ======================
class CPT:
    def __init__(self, alpha=0.88, beta=0.88, lambda_loss=2.25,
                 gamma_gain=0.61, gamma_loss=0.69, reference=0.0):
        self.alpha = alpha
        self.beta = beta
        self.lambda_loss = lambda_loss
        self.gamma_gain = gamma_gain
        self.gamma_loss = gamma_loss
        self.reference = reference

    def value(self, x):
        """Value function v(x): concave for gains, steeper for losses."""
        x = torch.as_tensor(x, dtype=torch.float32)
        mag = x.abs()  # raise magnitudes only, so negative inputs never yield NaN
        return torch.where(x >= 0, mag ** self.alpha,
                           -self.lambda_loss * mag ** self.beta)

    def probability_weight(self, p, gain=True):
        """Probability weighting w(p); not used by the training loop below."""
        p = torch.as_tensor(p, dtype=torch.float32).clamp(1e-6, 1 - 1e-6)
        # The exponent depends on whether the outcome is a gain or a loss
        g = self.gamma_gain if gain else self.gamma_loss
        return p ** g / (p ** g + (1 - p) ** g) ** (1 / g)

    def compute_cpt_advantage(self, rewards, gamma=0.99):
        """CPT-shaped, standardized returns for ONE time-ordered trajectory.

        The agent subtracts the critic baseline from these to get the advantage.
        """
        returns, R = [], 0.0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns, dtype=torch.float32)
        cpt_returns = self.value(returns - self.reference)
        # Standardize
        return (cpt_returns - cpt_returns.mean()) / (cpt_returns.std() + 1e-8)


# ====================== Actor-Critic network ======================
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # Actor (mean + state-independent log standard deviation)
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_logstd = nn.Parameter(torch.zeros(action_dim))
        # Critic
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = self.shared(x)
        mean = self.actor_mean(x)
        std = torch.exp(self.actor_logstd)  # already a std, do not exp() again
        value = self.critic(x)
        return mean, std, value


# ====================== CPT Agent ======================
class CPTActorCriticAgent:
    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99, device='cpu'):
        self.device = device
        self.gamma = gamma
        self.cpt = CPT()
        self.model = ActorCritic(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        # Holds the current episode in order: (state, action, reward, next_state, done)
        self.memory = deque(maxlen=2048)

    def select_action(self, state):
        state = torch.as_tensor(state, dtype=torch.float32,
                                device=self.device).unsqueeze(0)
        with torch.no_grad():
            mean, std, _ = self.model(state)
        action = torch.distributions.Normal(mean, std).sample()
        action = torch.clamp(action, -1.0, 1.0)  # LunarLander action range
        return action.squeeze(0).cpu().numpy()

    def update(self):
        """One on-policy update over the stored episode, then clear the buffer.

        Discounted returns are only meaningful on an ordered trajectory, so the
        buffer is consumed in order instead of being sampled at random.
        """
        if len(self.memory) < 2:  # std() needs at least two returns
            return 0.0
        states, actions, rewards, _, _ = zip(*self.memory)
        states = torch.as_tensor(np.array(states), dtype=torch.float32,
                                 device=self.device)
        actions = torch.as_tensor(np.array(actions), dtype=torch.float32,
                                  device=self.device)
        # CPT-shaped return targets, moved to the model's device
        targets = self.cpt.compute_cpt_advantage(
            list(rewards), self.gamma).to(self.device)

        # Forward
        means, stds, values = self.model(states)
        values = values.squeeze(-1)
        dist = torch.distributions.Normal(means, stds)
        log_probs = dist.log_prob(actions).sum(dim=-1)

        # Loss: the critic predicts the CPT targets and serves as the baseline
        advantages = targets - values.detach()
        actor_loss = -(log_probs * advantages).mean()
        critic_loss = (values - targets).pow(2).mean()
        loss = actor_loss + 0.5 * critic_loss

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()
        self.memory.clear()
        return loss.item()

    def store(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))


# ====================== Main training function ======================
def train_cpt_rl(episodes=800, render=False):
    # Newer Gymnasium releases may register this environment under a -v3 ID instead.
    env = gym.make("LunarLanderContinuous-v2",
                   render_mode="human" if render else None)
    state_dim = env.observation_space.shape[0]  # 8
    action_dim = env.action_space.shape[0]      # 2
    agent = CPTActorCriticAgent(
        state_dim, action_dim,
        device='cuda' if torch.cuda.is_available() else 'cpu')
    best_reward = -float('inf')

    for episode in range(episodes):
        state, _ = env.reset()
        episode_rewards = []
        done = False
        total_reward = 0.0
        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.store(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
            episode_rewards.append(reward)

        # Update once per episode, on the full ordered trajectory
        loss = agent.update()

        # Adaptive reference point (optional)
        if total_reward > best_reward:
            best_reward = total_reward
            agent.cpt.reference = float(np.mean(episode_rewards)) * 0.3

        if episode % 50 == 0:
            print(f"Episode {episode:4d} | Reward: {total_reward:8.2f} | "
                  f"Best: {best_reward:.2f} | Ref Point: {agent.cpt.reference:.2f}")

    env.close()
    return agent


if __name__ == "__main__":
    agent = train_cpt_rl(episodes=1000)
```
Usage notes
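- Install dependencies: `pip install "gymnasium[box2d]" torch numpy`
- Run: execute the script directly; it logs LunarLander training progress every 50 episodes (pass `render=True` to `train_cpt_rl` to watch the lander).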
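Key effects of CPT:
- Loss aversion (λ = 2.25) → the agent leans toward avoiding high-risk crashes, even if that costs some expected score.
- Adjust `lambda_loss` to see how behavior changes (the larger the value, the more conservative the agent).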
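The loss-aversion effect can be seen without training anything. The sketch below compares two hypothetical policies under different `lambda_loss` values; the outcome numbers and the helper `cpt_expected_value` are invented for illustration, and probability weighting is left out to isolate the effect of λ.

```python
def cpt_value(x, alpha=0.88, beta=0.88, lambda_loss=2.25):
    """CPT value function with the same defaults as the CPT class."""
    return x ** alpha if x >= 0 else -lambda_loss * (-x) ** beta


def cpt_expected_value(outcomes, lambda_loss):
    """Mean CPT value of equally likely outcomes (no probability weighting)."""
    return sum(cpt_value(x, lambda_loss=lambda_loss) for x in outcomes) / len(outcomes)


safe = [20.0, 20.0]      # hypothetical cautious policy: raw mean score 20
risky = [200.0, -100.0]  # hypothetical aggressive policy: raw mean 50, may crash

for lam in (1.0, 2.25, 4.0):
    s = cpt_expected_value(safe, lam)
    r = cpt_expected_value(risky, lam)
    choice = "safe" if s > r else "risky"
    print(f"lambda_loss={lam}: safe={s:.2f}, risky={r:.2f}, prefers {choice}")
```

With these made-up numbers the risky policy has the higher raw mean, yet it only wins under CPT when losses are not penalized extra (`lambda_loss=1.0`); at 2.25 and above the cautious policy is preferred.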
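Would you like me to adjust anything?
- Switch to a PPO version (more stable)
- Use a discrete environment (CartPole / FrozenLake)
- Add experience replay + Prioritized Replay
- Train on multiple environments in parallel (VectorEnv)
- Implement CPT in DQN / SAC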
