当前位置: 首页 > news >正文

混合注意力学习(1): 线性注意力

Prefill、Decode与KVCache

在开始本文之前,首先应该介绍一下什么是prefill,什么是decode,以及对应的KVCache。这样可以更好理解内存复杂度。
小学生们都日益了解了,现有的LLM大语言模型的组成成分主要是Transformer Block。其中,注意力机制具有如下的计算公式:

在推理过程中,我们采用 decoder-only 架构,因此注意力机制为“掩码自注意力”。流程如下:

从上面的流程图中,我们看到Key和Value是复用的,即每一次新的计算都需要用到先前的输入生成的Key和Value。因此我们将其称为KVCache。我们可以发现,每一个Transformer Block的KVCache都是随着序列长度线性增长的。因此整体的空间复杂度为 (假定)。

因此我们的prefill和decode流程简化来说是如下所示的,prefill要求多个seq并行进入进行计算,decode则每次接收一个上次生成的token进行计算。重要的两个指标为:TTFT(Prefill开始到第一个token生成所需要的时间),TPOT(每个token生成之间的时间)[1]。

[1:1]

[1:2]

同时我们也可以用roofline模型来刻画我们的序列长度,请求次数导致的计算上限和存储上限。对于短序列prefill和decode阶段,主要是存储受限(计算强度低,但是需要大量读写KVCache)。对于长序列prefill阶段,主要是计算受限(计算强度高,GEMM)

这样我们就介绍完了KVCache,Prefill和Decode,了解了他们的不同和受限情况。

混合注意力架构

混合注意力架构包括了稀疏注意力与线性注意力,而线性注意力机制又起源于这篇Transformers are RNNs的文章。我们都大致介绍一下作为我们的基本背景。

线性注意力

这里也可以阅读苏神的文章[2][3]。我们总结了一张图如下:

Transformers 就是 RNNs[4]

在上文中我们已经提到了经典的自注意力机制的计算,我们这里再展示一次:

对于上述式子,我们很容易想当然地采用近似函数来代替。我们首先分析我们的矩阵运算结果如下:

接下来我们就可以尝试用近似函数代替,也就是 :

这里距离我们的线性注意力还有一段距离,但是已经不远了。我们回想在古老的分离向量机中,为了实现非线性支持向量机,我们需要使用到核函数技巧。

  • 为了将非线性问题转换成线性问题,我们采用核函数技巧。对于输入空间 , 为特征空间(希尔伯特空间),如果存在映射函数 满足 , 则 为核函数, 为映射函数。[5]
  • 在注意力机制中,我们只需要要求函数非负来表示其概率性。对此可以参见[6]

因此,我们采用核函数技巧进行分解,就可以得到如下的公式:

很容易注意到我们可以采用结合律来提取出公因式 。注意到 ,采用结合律需要转置。因此我们有:

这样我们很容易可以看出:原先softmax的计算复杂度是 ,随着序列以 的平方复杂度增长。而核函数在维持原有隐维度的条件下保持 的线性复杂度增长,并且隐维度 仍然具有可优化的空间。这就是线性注意力的由来。

这样,如果我们考虑到推理模式下的decode-only掩码自注意力场景下,我们不会计算全部的序列 ,而是计算到当前序列 ,这样我们就有:

我们令 ,就可以得到:

很明显就是我们对应的最简单的RNN架构。这也说明了:(1) Transformer其实是大号RNN;(2)线性注意力在理论层面是可行的。大量的实验发现分母会导致严重的数值不稳定问题,并且可以无需映射函数直接采用分子参与计算。[7]。这样实际上为 .

我们最后再看看梯度的计算。注意在训练场景下我们是全序列,因此我们在给定分子 和损失函数 的条件下,参考[8]处的运算,我们有:

这样我们就把所有的梯度都计算出来了。

小贴士:在参考原有的链式法则基础上,适当通过各种转置方法保证梯度张量和参数张量保持一致(因为梯度张量应该和参数张量相同,这样才能利用梯度下降更新)可以提升我们的计算速度。

Fast Weight Programmers与DeltaNet[9]

线性注意力允许我们使用如下的方式更新当前状态模仿RNN:

这样对于一个长为 的序列:我们每次计算步数为 ,计算复杂度为 ,训练过程中需要的空间复杂度为 ,推理过程中需要的时间复杂度为 .

并且,在这篇论文[9:1]中, 实际上是一个关联性的内存,存储了当前瞬态从key到Value的映射。这样的更新可以看作是一个无上界的关联损失函数的梯度下降,从而持续强化最近的键值对,没有任何遗忘。也就是文章[7:1]中说的:

这样的持续无遗忘将会在长上下文中造成严重的干扰。(我们的人脑也会通过忘记无关紧要的,久远的非必要记忆来保证我们对当前上下文的专注)。

我的某位朋友指出我需要提供为什么梯度更新和上面的线性注意力更新是等价的。在此给出确切的证明。

对于此,[9:2] [10]提出了如下的更新方式,也称为Delta Rule。

  • 新的 到来。
  • (Read)取出上一次的 ,构造未更新前的前Key-我们看到Key和Value关联模式:
  • 通过一个学习率网络 来构造动态学习率 , 是激励函数
  • 通过学习率控制K-V关联性:。
  • 更新状态矩阵实现(Write)遗忘。

这样也等价于重建了一个无上界的Loss重建成如下形式:

通过学习系数 来单步梯度下降,修正自己当前时间步下的记忆关联 。这样的变换允许了硬件通过分块并行来提升计算速度。这篇文章中详细说明了针对线性注意力的高效并行[7:2]。

证明如下:

然后通过 来更新我们的权重。更新的公式则为:

这就是DeltaNet提出的重建损失函数——来实现一定的遗忘性。 也被称为 Delta 系数。

Gated DeltaNet[11]

为了进一步减少历史记忆和状态对现有的影响,Mamba2进一步通过权重衰减来实现对过去的遗忘:

更进一步,结合Delta规则实现遗忘:

这样就等价于通过如下的损失函数进行梯度下降。

因此 成为门控权重衰减系数(gating weight decay),这就是Gated Delta Network的具体形式。

线性注意力并行机制[7:3]

传统FWP形式的线性注意力并行

线性注意力在时间迭代的形式如下:

我们对比一下线性注意力的并行形式与迭代形式。针对并行形式,我们将 堆叠在一起形成一个整体,这样就形成了 。注意堆叠的方式是:, 均采用如同 的堆叠方式。这样我们就拥有了如下的计算公式:

而最终计算结果需要要求查询不能看到未来的键和值(不然就成透视未来了~),因此我们加入一个下三角掩码即可。这样最终输出就应该是:

我们比较一下并行算法和迭代算法的不同。在复杂度计算中,我们假设 ,这样更直观一些。我们每一层Transformer Block的复杂度如下:

算法时间复杂度空间复杂度计算步数
时间步迭代推理,训练
并行计算

诶?为什么并行计算方式的时间复杂度更高,但是执行时间更低呢?不要忘记并行计算的优势在于同时计算速度快,瓶颈因素转移到了计算步数上。对于时间迭代算法,我们无法充分发挥并行计算的优势。因此在长序列上很明显并行算法在计算步数上远小于迭代算法。在GPU上还可以充分利用tensorCore等用于GEMM的优势,时间步迭代则不行。但是并行算法内存占用很高,我们可以看到空间复杂度呈平方增长,这又失去了线性注意力的优势。

为了实现高效的计算,分块并行就成为了一个权衡两者的利弊的一个有效方式,这样可以充分利用计算资源的同时降低内存占用。

FWP分块并行机制

首先我们规定对应的符号。

  • 代表第 个分块。这个分块通过 的方式堆叠。
  • 代表 “第个分块中的第个列向量 ”。
  • 堆叠里面有个向量,是因为我们还规定 .
  • 上面的记号对 都成立。
  • ,代表第t个分块中的第 个状态矩阵。

这样,我们就可以改写我们的迭代步骤成混合形式。块内的某一个元素 则为:

这样对整一个块我们有如下公式,注意查询不能看到未来的key和value~:

这样我们就实现了分块并行的线性注意力策略。对于每一层Transformer block,计算步数则为 ,内存复杂度变为 ,计算复杂度则为 ,再次回到了线性注意力的计算复杂度!空间复杂度上,空间复杂度为 ,训练情况下则需要保存每一个chunk为,也维持了线性的增长!

DeltaNet 分块并行机制

DeltaNet 的更新公式如下:

其中我们有:

最直观的方式就是将后面的一部分重新表示成一个新的向量 ,并且将 吸收进去,可以得到如下公式:

从而得到整体的计算公式。这样,通过上一节的堆叠和掩码方式得到并行计算的方式,也就是

但是,这样的表示真的正确吗?我们存在如下的问题:(1) 需要上一个状态的直接计算,导致实际上无法针对 进行并行计算。(2) 计算每个 都需要上一个状态的状态矩阵 ,导致内存占用从 上升到 。

重新定义 降低内存占用

但是这同时也带来了新的问题:我们的空间复杂度从 上升到了 。回顾我们的导出过程,以及上面的公式,我们可以得到:

这就意味着我们在计算矩阵 的时候,总是需要保证我们至少存取了上一次的关联矩阵 ,这样我们的实际的内存复杂度就应该为 ()。对此,我们需要重新定义 ,不再存储过去的状态矩阵。通过数学归纳法来得到新的 。

回顾我们的导出过程,我们有:

假设 ,. 这样我们有归纳起始条件:,. 如果有 ,则有:

这样 将不再需要读取上一次的状态矩阵 ,每次计算只需要存取 . 此时的内存复杂度再次回到曾经的 。

实际上上面使用的数学归纳法的灵感来源于如下的矩阵计算相关:
HouseHolder变换和WY表示[12]
HouseHolder 变换:对于一个非零向量 ,如果一个矩阵 满足

则这个矩阵称为 HouseHolder 矩阵。 称为 HouseHolder 向量。对于一个向量 , 称为HouseHolder变换。
我们很容易发现 HouseHolder 变换是一个 rank-1 的修正。因为非零向量外积的秩永远为1,所有的列都落在 中。
WY表示:假设 是一个 rank-r 的修正,这样我们有 , ,.
证明:采用数学归纳法。我们假定 . 因此我们有:

这样很明显我们的WY表示是成立的。

但是,如同先前需要实现并行或者分块并行的理由相同——GPU等加速器更适合并行矩阵运算,时间步数 的算法不适合在对应硬件上实现。因此我们需要实现DeltaNet的分块并行运算

DeltaNet的分块并行

针对我们的DeltaNet的分块并行算法,我们需要做一系列比较复杂的变换。我们首先将原本的公式表示成如下所示。

我们需要充分利用广义householder变换的性质。因此我们定义如下:

因此我们重写 。通过循环迭代我们得到:

接着我们定义分块矩阵相关符号。

  • ,代表第t个分块中的第 个状态矩阵并且。
  • . 这对 同样适用。

这就是我们原文的初始分块模式。但是存储 和 在训练/prefill过程中需要 的存储空间(假设 ),我们可以通过类似上面的WY表示的数学归纳法来降低内存占用到 。

我们需要分析一下 和 。 我们将他们展开,可以得到:

http://www.jsqmd.com/news/1080573/

相关文章:

  • 魔兽争霸3辅助工具终极指南:5分钟解决所有兼容性问题
  • FDD大规模MIMO中鲁棒反向注水算法:应对CSI反馈挑战的工程实践
  • SQLServer RAG笔记5:为SQLServer 2025配置Ollama
  • 电池寿命预测的AI革命:微软开源工具BatteryML深度解析
  • 日志管理化技术中的日志收集日志分析日志存储
  • 游戏网络同步:状态同步与帧同步的选择与实现
  • DarkHole2靶场渗透实战:从信息收集到权限提升的完整路径解析
  • 嵌入式处理器选型实战:从以太网与硬件加密需求到MCF5475应用解析
  • 流式计算架构设计
  • 绝地求生压枪宏:用Lua脚本实现罗技鼠标精准后坐力控制的完整指南
  • Java CompletableFuture 并发性能优化
  • LangChain链式提示工程实战:从Rap生成器解剖AI工作流
  • Java网络编程NIO与Netty框架
  • 中科蓝讯音频SoC开发实战:从芯片选型到量产问题排查
  • 什么是基于文件的应用
  • 南宁青秀区跑了几家店,这家体验最舒服
  • AI编排实战:MuleSoft+LangChain双引擎企业级集成指南
  • 空中交通终端区进场排序优化:FOFFS与CPS策略的实时性能对比分析
  • 虚拟机DNS解析失败:systemd-resolved与127.0.0.53:53错误深度解析
  • AI文本分块实战指南:16种生产级策略与避坑方法
  • Python 异步爬虫限速方案
  • 前端组件库设计实现指南
  • Spielman猜想:正则图成立与一般图反例的谱图论解析
  • 专业视频对比工具全面指南:高效分析视频质量差异的终极方案
  • Python量化交易数据获取终极指南:用efinance轻松搞定四大金融市场数据
  • 直击痛点型:PLM、ERP、MES买齐了,但你的智能制造真的100%落地了吗?
  • 基于Spdlog + Qt的日志显示框架设计与实现
  • 快速掌握Apache Spark:从入门到实战的完整指南
  • VMware与Hyper-V冲突排查手册(2024版):从设备管理器异常驱动到WDDM GPU虚拟化抢占,覆盖12类真实产线案例
  • 3分钟完成FF14国际服中文汉化:开源工具让语言不再是障碍