当前位置: 首页 > news >正文

【vLLM 学习】Rlhf

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。

更多 vLLM 中文文档及教程可访问 →vllm.hyper.ai/

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/rlhf.py

# SPDX-License-Identifier: Apache-2.0""" 一个基于 vLLM 的 RLHF 简单实现演示,灵感来源于 OpenRLHF 框架 https://github.com/OpenRLHF/OpenRLHF 。 该设计采用训练进程(training processes)与推理进程(inference processes)分离的方案,它们运行在不同的 GPU 上。 训练进程向推理进程发送提示(prompts)以生成数据, 同时通过将模型权重从训练进程广播(broadcast)到推理进程 来实现模型权重的同步。 注意:本演示仅展示单个训练实例(training instance)和单个 推理实例(inference instance)的简单场景。 实际应用中可能存在多个训练实例和多个推理实例。 完整实现请参考 OpenRLHF 框架。"""importosimportrayimporttorch from ray.util.placement_groupimportplacement_group from ray.util.scheduling_strategiesimportPlacementGroupSchedulingStrategy from rlhf_utilsimportstateless_init_process_group from transformersimportAutoModelForCausalLM from vllmimportLLM, SamplingParams from vllm.utilsimportget_ip, get_open_port class MyLLM(LLM): def __init__(self, *args, **kwargs):# a hack to make the script work.# stop ray from manipulating CUDA_VISIBLE_DEVICES# at the top-level# 临时解决方案:确保脚本正常运行# 禁止 Ray 在顶层修改 CUDA_VISIBLE_DEVICES 环境变量os.environ.pop("CUDA_VISIBLE_DEVICES", None)super().__init__(*args, **kwargs)""" 开始训练过程,在这里我们使用 HuggingFace Transformer 作为在 GPU0上保存模型的示例。""" train_model=AutoModelForCausalLM.from_pretrained("facebook/opt-125m")train_model.to("cuda:0")""" 启动推理过程,我们使用 vLLM 在 GPU1和 GPU2。有关如何使用 ray 的详细信息, 请参考 ray 文档 https://docs.ray.io/en/latest/。""" os.environ["CUDA_VISIBLE_DEVICES"]="1,2"ray.init()pg_inference=placement_group([{"GPU":1,"CPU":0}]*2)ray.get(pg_inference.ready())scheduling_inference=PlacementGroupSchedulingStrategy(placement_group=pg_inference,placement_group_capture_child_tasks=True,placement_group_bundle_index=0,)""" 启动 vLLM 推理引擎。 在这里,我们使用`enforce_eager`减少开始时间。""" llm=ray.remote(num_cpus=0,num_gpus=0,scheduling_strategy=scheduling_inference,)(MyLLM).remote(model="facebook/opt-125m",enforce_eager=True,worker_extension_cls="rlhf_utils.WorkerExtension",tensor_parallel_size=2,distributed_executor_backend="ray",)# 从提示中生成文本。prompts=["Hello, my name is","The president of the United States is","The capital of France is","The future of AI is",]sampling_params=SamplingParams(temperature=0)outputs=ray.get(llm.generate.remote(prompts, sampling_params))foroutputinoutputs: prompt=output.prompt generated_text=output.outputs[0].text print(f"Prompt: {prompt!r}, "f"Generated text: {generated_text!r}")# 设置训练进程与推理引擎之间的通信master_address=get_ip()master_port=get_open_port()handle=llm.collective_rpc.remote("init_weight_update_group",args=(master_address, master_port,1,3))model_update_group=stateless_init_process_group(master_address, master_port,0,3, torch.device("cuda:0"))ray.get(handle)# 模拟训练,修改模型的权重。forname, pintrain_model.named_parameters(): p.data.zero_()# 同步从训练过程到推理引擎的权重。forname, pintrain_model.named_parameters(): handle=llm.collective_rpc.remote("update_weight",args=(name, p.dtype, p.shape))model_update_group.broadcast(p,src=0,stream=torch.cuda.current_stream())ray.get(handle)# 检查权重是否更新。assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))# 使用更新的模型生成文本,它们会胡说八道# 因为权重都是零。outputs_updated=ray.get(llm.generate.remote(prompts, sampling_params))foroutputinoutputs_updated: prompt=output.prompt generated_text=output.outputs[0].text print(f"Prompt: {prompt!r}, "f"Generated text: {generated_text!r}")
http://www.jsqmd.com/news/211121/

相关文章:

  • 【光子AI / Photon AI】整理2021~2026 在 AI Agent、Multi-Agent Systems、多智能体学习、多智能体强化学习、协同智能/代理型智能体 等方向的 Papers
  • 枚举类型:常量集合的优雅管理
  • 无人值守智能污水处理控制系统:威纶通触摸屏与西门子PLC协同运行,真实工程项目稳定运行一年多供...
  • Demo 骗了所有人?一做就会,一用就废!多模态 RAG 跨不过去的这道坎,看透了!
  • 通过合理建模与架构设计,90% 的“JOIN 需求”可转化为 ES 原生支持的高效查询。
  • ‌测试教育路径:大学课程 vs 自学——2026年软件测试从业者专业成长指南
  • 90%的程序员都在错误选择Embedding模型!6步评估框架+代码实战,让你避开所有坑,小白也能秒变向量专家!
  • 基于遗传算法优化的VMD信号去噪算法:样本熵与信噪比双重适应度函数提升信噪比及故障诊断特征提取研究
  • 美国地产交易被AI大模型颠覆,RAG+混合搜索效率提升40%,程序员都在学!
  • 测试人员压力管理:构建可持续的截止日期应对框架——面向软件质量守护者的专业生存指南
  • S32K144 Bootloader开发实战:CAN与串口双剑合璧
  • GRBL三轴在STM32F103C8T6上的移植与脱机运行控制指南:源码资料打包,含OLED屏...
  • 硕士论文过审第一步:paperzz 论文查重功能,怎么帮你避开重复率雷区?
  • MATLAB四旋翼仿真中的滑模控制、反步控制与PID控制方法及公式文献参考
  • IP5385至为芯支持C口双向快充的30W到100W移动电源方案芯片
  • 【Linux命令大全】003.文档编辑之pico命令(实操篇)
  • 生活电器:重塑日常的科技力量
  • WordPress数据可视化插件定制开发最佳公司
  • 深度探索无线充电黑科技:LCL-S拓扑结构的那些事儿
  • 学服务器训练AI模型:5步路径助力高效入门
  • 罗德与施瓦茨HMP4040 HMP4030可编程直流电源四通道
  • 基于STM32的智能红绿灯控制系统
  • Delta 台达PLC-EH3铆压机程序:3轴控制方案详解及电气设计(含MODBUS通讯、伺服...
  • 今日头条视频下载方法汇总 高清无水印 (2026 最新实测)
  • adb.exe logcatadb.exe: command not found
  • 【Linux命令大全】003.文档编辑之rgrep命令(实操篇)
  • JavaScript 中 async + await 和直接同步方式执行有什么区别和意义
  • 全球实验室耗材市场:技术驱动下的区域竞争与未来增长图谱
  • 【Linux命令大全】003.文档编辑之sed命令(实操篇)
  • 2026全新版Java面试八股文.pdf出炉, 简直把所有 Java 知识面试题写出来了