大模型“睡眠”机制:提升推理能力,训练成本却线性增长?
1. 长上下文困境
很长一段时间,「长上下文」是各大模型厂商军备竞赛焦点,从 128K 到 1M,再到更长上下文窗口。业界认为窗口足够大,模型就能记住更多内容、处理更复杂任务。但问题也随之而来,上下文越长,KV Cache 越臃肿,导致显存被「吃光」、推理速度缓慢、成本上升。而且把更多 token 放进窗口,不代表模型能将信息转化为可推理的长期记忆,在复杂推理任务中,模型常因「记不住细节」而翻车。
2. 新视角:语言模型需要睡眠
近日,卡内基梅隆大学(CMU)联合马里兰大学等在新论文《Language Models Need Sleep》中提出有意思的视角,让 LLM「睡一觉」。这里的「睡眠」是一种类似睡眠的「记忆巩固机制」。作者认为基于 Transformer 的大语言模型用于长程任务时,注意力机制扩展性较差,为此研究出该机制。在睡眠过程中,模型对累积的上下文执行 N 次离线递归前向传播,通过学习得到的局部规则更新状态空间模型(SSM)模块中的快速权重(fast weights)。推理阶段,这种方法把额外计算转移到「睡眠」阶段,保持模型「醒着」预测时的延迟不变。
3. 从动物睡眠获得启发
论文灵感来自动物睡眠中的记忆巩固过程。神经科学研究认为,动物从短期记忆到长期记忆的转移受海马体 replay 机制支持,尤其在睡眠期间,短期海马体记忆会被重新激活并巩固到皮层突触权重中。基于此,作者提出把上下文窗口记忆转移到持久权重中的方法。当模型上下文窗口在推理过程中被填满,模型进入「睡眠」状态,对累积的上下文执行多次前向传播,通过学习得到的局部规则递归更新 fast weights,此阶段模型不接收外部输入 token。巩固完成后,上下文窗口清空,模型带着更新后的 fast weights 继续运行。训练过程中,模型通过整个过程的反向传播进行端到端优化,以最大化睡眠之后的任务表现。大模型训练过程分为「醒着」和「睡眠」两个阶段。「醒着」阶段,模型像普通 Transformer 一样正常工作,接收长文本输入,快速给出预测和回复;「睡眠」阶段,模型进入「离线睡眠状态」,对积累的上下文进行 N 次循环往复的离线处理,将近期上下文中的关键细节转化为持久的 fast weights 并写入 SSM 模块。
4. 实验:睡得越久,推理越强?
为验证增加睡眠时 N 能否提升模型对「旧」上下文的推理能力,作者进行系列实验。以更接近自然语言的数学推理任务 GSM - Infinite 为例,它通过添加干扰 token 拉长题目,用所需算术操作数控制难度。作者在 Jet - Nemotron 2B 和 Ouro 1.4B 两个预训练模型上测试模型的「睡眠」机制。结果显示,题目越难,「睡眠」带来的提升越明显。对于 Jet - Nemotron 2B,6 次 sleep loop 将 6 步运算题准确率从 0.742 提升到 0.812,将 8 步运算题从 0.351 提升到 0.388;对于 Ouro 1.4B,4 次 sleep loop 将 6 步运算题准确率从 0.419 提升到 0.615,将 8 步运算题从 0.210 提升到 0.272。「睡眠」机制对简单题帮助相对不明显,在复杂任务中,「睡眠」阶段的额外计算开始发挥作用。
5. 局限性:效果明显,代价同样明显
作者坦言,这种方法通过把额外递归计算转移到巩固阶段,保持了预测阶段的单次前向传播延迟,但收益并非免费。训练过程中,需要执行 N 次更深的前向和反向传播,会让训练变慢且可能不稳定。执行 N 次带来效果明显提升,但训练成本也随其线性增长。这项工作目前主要是方法论探索,该方法主要贡献在方法论层面,评估基于受控合成任务和中等规模预训练模型,还不是在超大规模商用模型、真实长程 Agent 系统中充分验证的成熟方案。
