当前位置: 首页 > news >正文

大模型如何训练百万 Token 上下文:上下文并行与 Ring Attention

只用了几年时间,上下文窗口就从 4k 膨胀到 1000 万。Meta 发布的 Llama 4 Scout 的时候说这个模型支持 1000 万 Token,是 Llama 3 那 128k 的 78 倍。而Google Gemini 3 Pro 是 100 万,Claude 4 也桐乡市100万。

一次推理跑完整个代码库、几百篇论文、连续好几天的对话记录在技术上可行了,但问题是硬件跟不上。

405B 参数的模型,32 位精度下光权重就要 6.5TB 内存。再算上梯度、状态、激活值,后者还随上下文长度二次方增长。单台 NVIDIA HGX B300 配了 2.3TB HBM3e都不够。

这就逼着必须做多节点分布式训练和推理,几十上百块 NVIDIA Blackwell GPU 、NVLink 再加上 InfiniBand,就成了数据中心的标配。所以难点就变味了 GPU 之间的通信瓶颈。

并行化基础

模型或数据集超出单卡容量,就得上并行策略,但是每种策略本质上都是拿通信开销换内存空间。

数据并行是最直接的方案:整个模型复制到每张卡上,训练数据切开,每张卡跑不同的 batch跑完一步同步梯度。适合小模型,计算是瓶颈、内存不是问题的场景。

模型并行针对大模型:单卡装不下,就把模型拆开,不同的层放不同的卡上,按顺序跑。405B 这种规模只能这样,并且下游的卡得等上游算完中间是有空转的。

张量并行更极端:连单个矩阵乘法都塞不进一张卡。就需要把矩阵按行或按列切开,分到各卡上算,再通过 all-reduce 合起来。

但这些都有共同的局限。模型大、上下文又长到几百万 Token,张量并行也顶不住。因为注意力的二次方内存增长太凶,激活值直接占满显存。128k 上下文的激活值内存是 8k 的 16 倍,这个目前没办法,因为就是这么夸张。

上下文并行与序列并行

序列并行和上下文并行都是在设备间切序列来省内存,但切法不一样。

序列并行配合张量并行使用,只切那些非矩阵乘法的操作,比如层归一化、dropout。张量并行管不到的地方,序列并行接手,每张卡处理一部分激活值。两者配合能把序列撑长一些,但到 128k 以上还是会有问题,因为注意力的二次方增长是绕不过去。

上下文并行更彻底:整个序列在所有模块里都切开,包括注意力。每个操作拿到的都是分区后的序列。百万级上下文的训练就靠这个,把激活值的内存占用分摊到各卡上。

注意力一直是最麻烦的问题,因为模型的其他操作基本都是逐 Token 独立处理并行起来很自然。但注意力不行,每个 Token 都要"看"序列里所有其他 Token。序列切到多张卡上之后,GPU 1 的 Token 怎么看 GPU 2 的 Token?直接等数据传完再算,整个流水线就卡住了。

Ring Attention 就是来解决这个问题的,让多节点多卡的大模型训练和推理能在大规模数据中心里跑起来。

Zig Zag Ring Attention:通信和计算重叠

Ring Attention 把 GPU 组织成环形拓扑。每张卡的工作流程是这样的:持有序列中 Q、K、V 张量的一个分块;用本地的 K 和 V 给自己的 Q 分块算注意力;把 K 和 V 传给环里的下一张卡;从上一张卡接收 K 和 V;循环往复,直到所有 Q Token 都跟所有 K/V Token 算完注意力。

关键在于计算和通信是重叠的。GPU 1 拿着当前的 K/V 分块算注意力的时候,同时在从 GPU 0 接收下一批分块。通信延迟减少了,因为不用干等数据全到了再开算。

GPT 这类自回归模型有个额外的麻烦:Token 只能看前面的 Token不能看后面的。所以会导致负载不均衡有些卡会空转,Zig-Zag Ring Attention 解决这个问题的办法是交错分配,不是按顺序切块而是 GPU 0 拿 Token [0, 4, 8…],GPU 1 拿 [1, 5, 9…],以此类推。每张卡都拿到早期和晚期 Token 的混合,因果注意力计算时负载就均衡了环里不会有卡闲着。

但是代价是索引逻辑稍微复杂一点,不过大规模场景下性能收益很可观,因果掩码下也能做到接近满 GPU 利用率。

上下文并行与 Ring Attention 常见问题

上下文并行把输入序列切到多张 GPU 上,突破训练时的内存限制。跟张量并行、数据并行不同,它在所有模型模块里都切序列维度。单卡装不下的百万级 Token 上下文,只有靠这个才能训。

Ring Attention 把 GPU 排成环,每张卡一边算当前数据的注意力,一边把键值对往下传。通信和计算重叠,全对全的注意力计算不用等完整序列数据到齐,GPU 不会干等。

而序列并行只切非矩阵乘法操作(层归一化之类的),配合张量并行用。上下文并行在所有模块里都切序列,包括注意力。超过 128k Token 的上下文必须用后者,因为激活值内存二次方增长太猛了。

为什么 Zig-Zag Ring Attention 比标准 Ring Attention 更好?

Zig-Zag 用交错分配代替顺序分配,因果掩码计算时各卡负载更均衡。标准 Ring Attention 会让后面的卡等前面的分块,造成计算空闲。Zig-Zag 把早期和晚期 Token 均匀撒到各卡上,避免这个问题。

那么训练百万级 Token 上下文的模型需要什么硬件?

多节点 GPU 集群,配 HBM 内存,加高速互连——NVIDIA NVLink 1.8TB/s 或者 InfiniBand。405B 参数模型 32 位精度从头训练加推理,4 台 NVIDIA HGX B300 的机架部署是个不错的起点。

总结

上下文并行本质上是拿通信开销换内存空间,而网络带宽是最要命的瓶颈。Ring Attention 要在 GPU 之间不停交换键值对,传输时间一旦超过计算时间,各卡就会从"边算边传"退化成"等数据"。NVIDIA NVLink 1.8TB/s 加 InfiniBand 的高速互连,在多机架部署里不是可选项是必需品。互连带宽必须匹配 GPU 计算吞吐量,否则上下文并行的效果会大打折扣。

https://avoid.overfit.cn/post/fd6022b9196942ffb737ba306925b6db

by Khang Pham

http://www.jsqmd.com/news/323174/

相关文章:

  • 【计算机毕业设计案例】基于springboot的t智慧驾培综合服务管理平台学车驾校管理系统(程序+文档+讲解+定制)
  • 超越Python:下一步该学什么编程语言?
  • 【游戏推荐】超市模拟器 PC手机双端 送修改器(Supermarket Simulator)免安装中文版
  • C++与Kubernetes集成
  • 社会网络仿真软件:NodeXL_(8).网络属性计算:度中心性、介数中心性、接近中心性
  • 【计算机毕业设计案例】基于springboot+vue的废旧品线上回收系统旧物回收管理系统(程序+文档+讲解+定制)
  • Python项目结构:如何组织你的代码
  • 【WTMSVM诊断网络】基于小波多尺度同步压缩变换WMSST结合MCNN-SVM多尺度卷积神经网络和支持向量机的故障诊断研究附Matlab代码
  • 写给新手的Python代码风格规范(PEP 8)
  • 使用Flask快速搭建轻量级Web应用
  • 【游戏推荐】绝世好武功 免安装中文版
  • C++代码冗余消除
  • 【游戏推荐】工业的崛起2 全DLC(Rise of Industry 2)免安装中文版
  • 移动设备上的C++优化
  • 缓存读写代码逻辑的正确姿势
  • Trust is All You Need | 2025通付盾智能体安全进展盘点
  • 洛谷P9869 [NOIP2023] 三值逻辑 题解
  • 一、C++简介与环境配置
  • 【游戏推荐】NBA 2K 欢乐竞技场2 (NBA 2K Playgrounds 2)免安装中文版
  • 金融领域元学习在模型快速适应中的应用
  • 模板元编程调试方法
  • 基于单片机的无线病床呼叫系统设计
  • Python日志记录(Logging)最佳实践
  • 大语言模型微调数据对齐五大核心算法SFT、RLHF、DPO、PPO、GRPO
  • AI Agent在预测分析中的应用
  • 2026年AIR SCI1区TOP,基于三维 Rényi 熵模型的多特征融合与量子混合算法+阿尔茨海默病脑图像分割,深度解析+性能实测
  • C++中的适配器模式变体
  • 5种落地性最强的对齐微调数据集格式
  • GPU thread 概念
  • 大数据清洗:提高数据质量的10个实用技巧