当前位置: 首页 > news >正文

【vLLM 学习】Rlhf Utils

vLLM 是一款专为大语言模型推理加速而设计的框架,实现了 KV 缓存内存几乎零浪费,解决了内存管理瓶颈问题。

更多 vLLM 中文文档及教程可访问 →https://vllm.hyper.ai/

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/rlhf_utils.py

import torch def stateless_init_process_group(master_address, master_port, rank, world_size, device): """ vLLM 提供 `StatelessProcessGroup` 来创建进程组, 无需考虑 torch.distributed 中的全局进程组。 建议先创建 `StatelessProcessGroup`,然后初始化 外部(训练进程)与 vLLM 工作进程之间的数据平面通信(NCCL)。 """ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup pg = StatelessProcessGroup.create(host=master_address, port=master_port, rank=rank, world_size=world_size) pynccl = PyNcclCommunicator(pg, device=device) return pynccl class WorkerExtension: """ vLLM 工作进程的基类。 通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。 这种方式使代码能同时兼容 vLLM V0 和 V1。 注意:我们在单独模块中定义此类,主模块应将完整限定名 作为 `worker_extension_cls` 参数传递。 """ def init_weight_update_group(self, master_address, master_port, rank_offset, world_size): from vllm.distributed.parallel_state import get_world_group rank = get_world_group().rank + rank_offset self.model_update_group = stateless_init_process_group( master_address, master_port, rank, world_size, self.device, ) def update_weight(self, name, dtype, shape): weight = torch.empty(shape, dtype=dtype, device="cuda") self.model_update_group.broadcast(weight, src=0, stream=torch.cuda.current_stream()) self.model_runner.model.load_weights(weights=[(name, weight)]) del weight def check_weights_changed(self): """ Check if the weights are updated to 0. """ """ 检查权重是否已更新为 0。 """ weights_updated = True for name, p in self.model_runner.model.named_parameters(): weights_updated = weights_updated and torch.allclose( p, torch.zeros_like(p)) return weights_updated class ColocateWorkerExtension: """ vLLM 工作进程在协同部署场景下的基类。 通过定义扩展类,无论底层工作进程类是什么,代码都能正常工作。 这种方式使代码能同时兼容 vLLM V0 和 V1。 注意:我们在单独模块中定义此类,主模块应将完整限定名 作为 `worker_extension_cls` 参数传递。 """ def report_device_id(self) -> str: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(self.device.index) return self.device_uuid def update_weights_from_ipc_handles(self, ipc_handles): handles = ipc_handles[self.device_uuid] device_id = self.device.index weights = [] for name, handle in handles.items(): func, args = handle list_args = list(args) # the key is to change device id to the current device id # in case two processes have different CUDA_VISIBLE_DEVICES # 关键是将设备 ID 改为当前设备 ID, # 以防两个进程有不同的 CUDA_VISIBLE_DEVICES list_args[6] = device_id tensor = func(*list_args) weights.append((name, tensor)) self.model_runner.model.load_weights(weights=weights) torch.cuda.synchronize() def check_weights_changed(self): """ 检查权重是否已更新为0。 """ weights_updated = True for name, p in self.model_runner.model.named_parameters(): weights_updated = weights_updated and torch.allclose( p, torch.zeros_like(p)) return weights_updated
http://www.jsqmd.com/news/268931/

相关文章:

  • Day25-ComfyUi环境搭建
  • Golang原理剖析(defer、defer面试与分析)
  • 攻防世界backup
  • gitee分支
  • Manus官方揭秘Sandbox云计算机:智能体的云端 AI 助手与智能计算环境
  • 炒股别太努力:量化交易正在“收割”最认真的投资者?
  • 手把手搭建本地RAG知识库!实现文档秒检索。
  • VP引导定位软件-定位纠偏(带角度)
  • PL3327系列(PL3327CD/CS/CE/CF) 18W AC/DC反激式开关电源芯片方案
  • 使用MCP执行代码:让Agent效率提升98.7%
  • 具备这5大潜质的人,天生就是卖货王者
  • 基于YOLOv8的交通事故车辆损伤检测与事故严重程度分级项目识别项目
  • Uniapp苹果内购支付全流程指南:从集成到配置的完整复盘
  • 哈尔滨特色美食口碑大赏!对青烤鹅力断层领先,成游客必打卡爆款 - 资讯焦点
  • 【数据分析】基于matlab辅导功能和ISSR-MDF模型的综合预警指标【含Matlab源码 14993期】
  • git 如何切换到123分支?
  • 深入解析:基于非官方接口的企业微信外部群批量创建与效率重构
  • 如何一次提交,提交到两个分支上?
  • 小国护照热度不减:2025年-2026年移民市场服务模式观察 - 资讯焦点
  • 【数据分析】辅导功能和ISSR-MDF模型的综合预警指标【含Matlab源码 14993期】
  • ‌AI测试覆盖率提升秘籍:从70%到95%的跨越‌
  • 移民市场深度观察:如何在信息洪流中甄选可靠的移民机构 - 资讯焦点
  • 【数字信号去噪】吕佩尔狐算法优化变分模态分解RFO-VMD数字信号去噪(优化K值 alpha值 综合指标 适应度函数包络熵)【含Matlab源码 14994期】
  • LLMs、RAG、AI Agent三个到底什么区别?
  • 2026年主流云游戏平台深度横评:硬件架构、网络性能与定价策略,谁是全能王者? - 资讯焦点
  • Vivado安装失败原因分析与修复方法汇总
  • UDS诊断服务(ISO 14229-1)
  • 2026气体检测仪优质厂家排行榜 实力之选 - 资讯焦点
  • 如何利用工厂大脑提升汽车制造的质量与效率?
  • 工业AI智能体如何提升汽车制造效率与良率?