当前位置: 首页 > news >正文

PPO

for batch_prompt in prompt_dataset:batch_response = active_model.generate(batch_prompt)batch_data = concat(batch_prompt, batch_response)batch_scores = reward_model(batch_data)batch_all_probs, batch_probs, batch_all_values = active_model.forward_pass(batch_data)ref_all_probs, ref_probs, ref_all_values = ref_model.forward_pass(batch_data)kls = compute_KL(batch_all_probs, ref_all_probs)rewards = compute_rewards(batch_scores, kls)advantages = compute_advantages(batch_all_values, rewards)returns = advantages + batch_all_valuesfor i in range(epoch):active_all_probs, active_probs, active_all_values = active_model.forward_pass(batch_data)loss_state_value = torch.mean((returns - active_all_values) ** 2)ratio = active_probs / batch_probsloss_ppo = torch.mean(-advantages * ratio)loss = loss_ppo + value_loss_rate * loss_state_valueloss.backward()optimizer.step()optimizer.zero_grad()

上面的代码是PPO训练的整体代码,参考教学视频:

https://www.bilibili.com/video/BV1rixye7ET6?spm_id_from=333.788.videopod.sections&vd_source=da862fa7a218e81897b55d7e24fe26ee

https://www.bilibili.com/video/BV1iz421h7gb?spm_id_from=333.788.videopod.sections&vd_source=da862fa7a218e81897b55d7e24fe26ee

https://www.bilibili.com/video/BV1enQLYKEA5/?spm_id_from=333.1387.homepage.video_card.click&vd_source=da862fa7a218e81897b55d7e24fe26ee


四个模型

基准模型(ref_model) 训练模型(activate model) 奖励模型(reward model) 状态价值模型(state_value model)

其中训练模型和状态价值模型只有输出头不同,在代码里体现为:active_model 同时包含策略头(policy head)和状态价值头(value head)

image-20251028151952344

scores估算

batch_response = active_model.generate(batch_prompt)  #采样一次
batch_data = concat(batch_prompt, batch_response) #拼接prompt+result
batch_scores = reward_model(batch_data) #PPO的奖励模型,只输出seq_len的最后一个位置的score,其他位置为0
batch_all_probs, batch_probs, batch_all_values = active_model.forward_pass(batch_data)
ref_all_probs, ref_probs, ref_all_values = ref_model.forward_pass(batch_data)
kls = compute_KL(batch_all_probs, ref_all_probs)
rewards = compute_rewards(batch_scores, kls)  #eg. batch_scores+(-0.2)*kls

计算基准模型和训练模型的KL散度,并利用KL散度和scores计算rewards

score计算,即GRPO(Group Relative Policy Optimization)的主要创新,相比PPO不只采样一次,而是使用active_model采样多次,得到result与多个scores序列,然后对其进行标准化。

image-20251028151908583

GAE 广义优势估计:中和偏差与方差计算优势函数

image-20251028151926780

通过advantages和values相加计算values head labels即returns,让state_value model拟合这个returns值


一个batch训练阶段

对一个batch数据进行epoch次的更新,loss分别是loss_ppo和loss_state_value,更新active model

http://www.jsqmd.com/news/24675/

相关文章:

  • 【SPIE出版|EI检索稳定】2025年机电一体化与轨道交通国际学术会议(MRT 2025)
  • 脑电数据PCA处理及SVM分类
  • T671195 于凋亡季节中的我们
  • 2025年临沂营业执照注册推荐:华恒财税的专业选择
  • 2025 年盐城异常处理,盐城行业资质,盐城财务代账,盐城会计代账公司最新推荐,聚焦资质、案例、售后的五家公司深度解读
  • 如何在Windows下开发输入法:Mini How to
  • 2025 年 10 月盐城公司变更,盐城地址挂靠,盐城商标注册公司最新推荐,聚焦资质、案例、售后的五家公司深度解读
  • 第一天学习
  • AI元人文:星火与土壤
  • 5-4-其他查询 - 实践
  • K3s + Sysbox:让容器拥有“虚拟机的灵魂”
  • 题解:AT_abc200_e [ABC200E] Patisserie ABC 2
  • CF1996G Penacony
  • 远程命令执行漏洞、SSRF、XXE、tomcat弱口令漏洞
  • Ollama API 交互
  • 项目冷场?用禅道协作白板激活团队的创新思维!
  • xxx.ped 在生物信息学中是什么?
  • Ollama 基本概念
  • 2025年桥洞力学板市场趋势与选购指南:江苏同芯木业江苏行业领先
  • 2025年桥洞力学板行业发展趋势与前五厂家推荐
  • 2025年10月桥洞力学板品牌综合评测与行业趋势分析
  • 2.HD302-070 socket can调试笔记1
  • 如何使用FlareSolverr来抓取Cloudflare网站 - 狼人:
  • 吴恩达深度学习课程二: 改善深层神经网络 第一周:深度学习的实践(一)
  • 云端微信 - 随时随地在浏览器访问
  • Ollama 运行模型
  • 【往届EI、Scopus已检索|ACM独立出版】第二届经济数据分析与人工智能国际学术会议 (EDAI 2025)
  • win11后台程序cpu高占用问题
  • 线段树的各种姿势
  • 2025 年矿井轴流通风机,矿井抽出式轴流对旋通风机,矿井压入式对旋轴流通风机,FKD 系列矿井压入式对旋轴流通风机厂家最新推荐,实力品牌深度解析采购无忧之选