当前位置: 首页 > news >正文

FlashAttention(1):从内存模型到 Online Softmax

在 “Attention Is All You Need” 这篇论文中,我们了解了注意力机制的三个关键矩阵:Q (Query)、K (Key) 和 V (Value)。
在标准的注意力计算流程中,核心步骤包括计算注意力得分矩阵 \(S=QK^T\) 以及 Softmax 后的概率矩阵 \(P\)。当序列长度 \(R\) 很大时,这些 中间结果(\(S\)\(P\) 矩阵) 的维度高达 \((R,R)\),需要占用巨大的显存空间。
传统的实现方式必须频繁地在 GPU 高速缓存 (SRAM)GPU 显存 (HBM) 之间读写这些巨大的中间矩阵。当矩阵规模扩大时,GPU 的计算能力实际上被 显存带宽(Memory Bandwidth) 限制,导致 GPU 单元大部分时间处于空闲等待状态,无法充分发挥其计算潜能。
为了解决这一瓶颈,FlashAttention 论文提出了一种方法。该方法不同于传统计算方法,设计该方法时是基于 IO 带宽受限这一前提,利用 Tiling(分块) 技术计算输出矩阵 \(O\),并在 SRAM不需要中间结果 \(S\)\(P\) 矩阵就能完成 Softmax 归一化。这种方法能够极大程度地降低对 HBM 的访问次数,从而显著提升 GPU 的计算效率和速度。

硬件性能分析

GPU 存储层次

跟 CPU 存储层次类似,在 GPU 中也存在不同的存储介质,不同的介质之间的访问速度以及容量有着巨大的差距,这里不再过多赘述:

两种计算操作

计算密集:程序运行时,GPU 处于满负荷工作状态的时间占据程序运行的绝大部分,这类程序主要受限于计算吞吐量,不会频繁大量的访问显存。典型应用为大矩阵乘法和多通道卷积计算。
访存密集:程序运行时需要频繁大量的访问显存,GPU 出现频繁的空等,导致 GPU 的计算资源被浪费,典型应用为如元素操作(激活函数、Dropout)以及归约化操作(求和、softmax、批归一化、层归一化)。

传统 Attention 实现

Attention 计算有三个关键矩阵,分别是 \(Q\), \(K\), \(V\) \(\in \mathbb{R}^{N \times d}\) ,其中 \(N\) 是输入序列的长度,\(d\) 是头维度,我们想使用以上三个矩阵计算出输出矩阵 \(O \in \mathbb{R} ^ {N \times d}\)

\[S=QK^{T \in \mathbb{R} ^{N \times N},} P=softmax(S) \in \mathbb{R} ^{N \times N}, O=PV \in \mathbb{R}^{N \times N} \]

在上面的公式中有两个中间变量,分别是 \(S\)\(P\),每次生成这两个中间变量时都需要向显存中存储,之后计算下一步时再将其从显存中取出,但是最后这两个中间变量的值对我们又没有什么用。并且这两个中间变量矩阵非常庞大(\(S,P \in \mathbb{R}^{N \times N}\)\(N\) 的维度非常高,比如在 GPT 2 中,\(N\) = 1024),对显存的频繁大量访问将 Attention 操作变为了访存密集型程序,GPU 的算力资源被大大浪费了。

步骤 操作描述 (Operation) 内存操作 (I/O) 中间结果
1 加载 \(Q\)\(K\) 分块,计算得分矩阵 \(S = QK^T\) 读 (HBM) \(Q, K\) 分块,写 (HBM) \(S\) \(S\)
2 读取 \(S\),计算概率矩阵 \(P = \text{softmax}(S)\) 读 (HBM) \(S\)写 (HBM) \(P\) \(P\)
3 加载 \(P\)\(V\) 分块,计算输出矩阵 \(O = PV\) 读 (HBM) \(P, V\) 分块,写 (HBM) \(O\) \(O\)
4 返回 \(O\) - -

FlashAttention 核心思想

FlashAttention 旨在将 \(QK^T\)Softmax\(PV\) 三个步骤融合 (Fusion) 为一个 I/O 感知的操作,从而在 SRAM 上在线完成计算,避免将中间结果写入 HBM。
由于 SRAM 容量有限,无法加载全部 \(Q,K,V\) 矩阵,该方法利用 Tiling(分块) 技术,将大矩阵分解成小块,逐次加载进 SRAM 中进行计算,并将结果矩阵 \(O\) 一部分写回 HBM。
对于矩阵乘法 \(S=QK^T\)\(O=PV\),利用现有的分块乘法即可处理。然而,对于 \(P=softmax(S)\) 来说,由于 Softmax 固有的全局依赖性(必须看到输入 \(S\) 矩阵的每一行(对应一个 \(Q\) Token)的全部值才能正确计算归一化分母)。因此,如何在这种分块限制下保证 Softmax 的正确性,才是 FlashAttention 解决问题的核心挑战

稳定 Softmax 算法

Softmax 分块计算之前先介绍一下 Numercally Stable Softmax 算法。
原始 Softmax 的定义如下:

\[Softmax(x)_j = \frac{e^{x_{i}}}{ \sum_i e^{x_i}}​​ \]

目前常用的 FP16、FP32 浮点数格式所能表示的最大值分别是\(3.4 \times 10^{38}\)\(65504\),当 \(x^i\) 比较大时,\(e^{x_i}\) 就会溢出,无法被正确表示。而当 \(x_i\) 比较小时,\(e^{x_i}\) 将会非常接近 0,则 \(Softmax(x)_i\) 将会出现除 0 下溢错误。
在指数函数中存在一种特性称为平移不变性,考虑这个公式:

\[\frac {e^a}{e^b} \]

将分子分母同除一个常数 \(C=e^c\)

\[\frac {e^a/C}{e^b/C}=\frac {e^a/{e^c}}{e^b/{e^c}}=\frac {e^{a-c}}{e^{b-c}} \]

因此,我们可以根据 Softmax 的平移不变性这一特性对原始公式进行改造:首先,获取输入向量 \(x\) 中的最大值 \(m(x)\),然后将 Softmax 公式的分子和分母都除以 \(e^{m(x)}\),这确保了 Softmax 的值保持不变。在这一缩放操作过程中,由于分子 \(e^{x_i}\)​ 中的 \(x_i\) 减去了 \(m(x)\),因此所有的分子项 \(e^{x_i}​−m(x)\) 都将维持在一个 \((0,1]\) 的安全区间内,从而彻底消除了指数运算产生上溢的风险。同时,分母由所有这些分子项相加得到,并且其中至少存在一项 \(e^{{m(x)}−x_i=0}\) 等于 1,这使得分母能够保证大于等于 1,有效避免了分母为零导致的下溢问题

\[Softmax(x)_j​=\frac {e^{x_i}​−max(x_i)}{\sum_{i}​e^{{x_i}​−max(x_i)}}​ \]

完整公式如下:

\[\begin{gather} m(x):=\max_{i} x_{i} \tag{1} \\ f(x):=[e^{(x_1-m(x))} ... e^{x_{B-m(x)}}] \tag{2} \\ \ell(x):=\sum_{i} f(x)_{i} \tag{3} \\ softmax(x):=\frac {f(x)}{\ell(x)} \tag{4} \\ \end{gather} \]

Softmax 分块计算

现在有两个向量\(x^{(1)}, x^{(2)}\) ,将这两个向量拼接成一个向量 \(x=[x^{(1)}, x^{(2)}] \in \mathbb{R}^{2B}\),有以下公式:

\[\begin{gather} m(x)=m([x^{(1)} \quad x^{(2)}]) \tag{1} \\f(x)=[e^{m(x^{(1)})-m(x)}f(x^{(1)}) \quad e^{m(x^{(2)})-m(x)}f(x^{(2)})] \tag{2} \\\ell(x)=\ell([x^{(1)} \quad x^{(2)}]) = e^{m(x^{(1)})-m(x)}l(x^{(1)})+e^{m(x^{(2)})-m(x)}l(x^{(2)}) \tag{3} \\softmax(x)=\frac {f(x)}{\ell(x)} \tag{4} \end{gather} \]

对于第二个公式来说:

\[\begin{gather} f(x^{(1)}) = e^{x^{(1)}-m(x^{(1)})}\tag{1}\\f(x^{(2)}) = e^{x^{(2)}-m(x^{(2)})}\tag{2} \end{gather} \]

对于 \(f(x)\) 来说,并不像 \(m(x)\) 能够直接由 \(x^{(1)} \quad x^{(2)}\) 拼接而来,这是因为 \(f(x^{(1)})\)\(f(x^{(2)})\) 的缩放基准分别为 \(m(x^{(1)})\)\(m(x^{(2)})\),而非全局最大值 \(m(x)\)。为了将缩放基准转换为 \(m(x)\),需要找到一个合适的缩放因子 \(C_1\),使得 \(f(x^{(1)}) \cdot C_1 = e^{x^{(1)}-m(x)}\)。解一下这个方程,得出 \(C_{1} = e^{m(x^{(1)})-m(x)}\),同样的,得出 \(C_{2} = e^{m(x^{(2)})-m(x)}\),然后将这两个缩放因子与 \(f(x^{(1)})\)\(f(x^{(2)})\) 相乘,得出整个 \(f(x)=[e^{m(x^{(1)})-m(x)}f(x^{(1)}) \quad e^{m(x^{(2)})-m(x)}f(x^{(2)})]\)
我们再利用相反的思路来考虑这个问题,\(m(x)\) 肯定等于 \(m(x^{(1)})\)\(m(x^{(1)})\),我们就假设 \(m(x)=m(x^{(2)})\),则 \(f(x)\) 的第二项就只剩 \(f(x^{(2)})\) 了,将第一项展开为 \(e^{m(x^{(1)})-m(x)}e^{x^{(1)}-m(x^{(1)})}=e^{x^{(1)}-m(x)}=e^{x^{(1)}-m(x^{(2)})}\),由此,我们成功地将两个 Softmax 分子项 \(f(x^{(1))})\)\(f(x^{(2)})\) 都转换到了统一的、以 \(m(x^{(2)})=m(x)\) 为基准的数值空间。
对于公式 \((3)\) 来说,其原理与公式 \((2)\) 一致。

逐步更新

存在向量 \(x\)\(x = [x^{(1)}, x^{(2)}, \dots, x^{(N)}]\)
假设当前已经处理了前 \(k-1\) 块,得到了:

  • 当前全局最大值: \(m^{(k-1)}\)
  • 当前全局归一化项: \(\ell^{(k-1)}\)
    现在处理新块 \(k\),它有:
  • 局部最大值: \(m_k\)
  • 局部归一化项:\(\ell_k\)
    然后更新全局最大值 \(m^{(k)}\) 与全局归一化因子 \(l^{(k)}\)

\[\begin{gather} m^{(k)}=max(m^{(k-1)},m_{k}) \\ \ell^{(k)} = e^{m^{(k-1)-m^{(k)}}} \ell^{(k-1)} + e^{m_k-m^{(k)}}\ell_k \end{gather} \]

FlashAttention 算法伪代码

输入:

  • 矩阵 \(Q, K, V ∈ \mathbb{R}^{N×d}\)(存放在 HBM)
  • 片上 SRAM 容量 \(M\)

1. 设置分块大小

  • \(B_c = \lceil \frac {M} {4d}\rceil\)
  • \(B_r = min(\lceil {\frac {M} {4d}} \rceil, d)\)

2. 初始化(存于 HBM)

  • \(O = 0_{N×d}\)
  • \(\ell = 0_{N}\)
  • \(m = -\infty_{N}\)

3. 分块

  • \(Q\) 按行划分为
    \(T_r = \lceil \frac {N} {B_r} \rceil\)
    得到块:\(Q_1 … Q_{T_r}\),每块大小 \(B_r × d\)
  • \(K, V\) 按列划分为
    \(T_c = \lceil \frac {N} {B_c} \rceil\)
    得到:\(K_1 … K_{T_c}\)\(V_1 … V_{T_c}\),每块大小 \({B_c × d}\)
  • 输出及辅助变量对应分成:
    \(O_i, \ell_i, m_i\)(大小分别为 \(B_r × d\)\(B_r\)\(Br\)

4. 主循环

图片

一些自己的理解

分块大小的选择

为什么要这样选择分块大小?

  • \(B_c = \lceil \frac {M} {4d}\rceil\)
  • \(B_r = min(\lceil {\frac {M} {4d}} \rceil, d)\)
    这是因为 FlashAttention 的一次计算要在 GPU 的 SRAM 中至少存放四个矩阵:\(Q_{i}, K_{i}, V_{i} ,O_{i}\)
    这四个矩阵总大小为:\(2B_{c}d+ 2B_rd\)
    在 FlashAttention 的设计策略倾向于一半的 SRAM 分配给 \(K/V\),一半给 \(Q/O\)
    假设 \(B_r\)\(B_c\) 近似,则四个矩阵总大小为:\(4B_{r}d \approx 4B_{r}d\)
    为了保证这四个矩阵能够同时放入 SRAM,需要满足 \(B_{r} \leq \frac {M} {4d}\),而对于 \(B_r\) 的另一个值 \(d\) 来说,这是因为在计算 softmax 的时候,如果某行变得太大,则会降低计算效率。

主循环更新 \(O_i\) 操作

\(diag()\)函数为对角矩阵构造函数,该函数能够将一个向量转换为一个对角矩阵。
\(O_i = diag(\ell^{new}_i)^{-1}(diag(\ell_i)e^{m_i - m^{new}_i}O_i+ e^{\tilde{m}_{ij} - m^{new}_i} \tilde{P}_{ij}V_j)\)拆分为三部分:

  • Part1: \(diag(\ell_i)e^{m_i - m^{new}_i}O_i\)
  • Part2: \(e^{\tilde{m}_{ij} - m^{new}_i} \tilde{P}_{ij}V_j\)
  • Part3: \(diag(\ell^{new}_i)\)
    因此整个公式可以简化为:\(O_{i}= \frac {Part1 + Part2} {Part3}\)
    现在理解一下为什么要这样做,对于 Part1 来说,\(diag(\ell_i)O_i\) 操作在恢复旧的总和(\(\ell_i\)是旧的归一化常数),\(diag(\ell_i)e^{m_i - m^{new}_i}O_i\)再用新的归一化常数计算以新的 \(m^{new}_i\)为基准的旧总和;Part 2 是计算当前块 \(K_j​,V_j\)​ 的贡献项。它是将基于局部最大值 \(\tilde{m}_{ij}\)​** 算出的加权总和 \(\tilde{P}_{ij}\),通过乘以指数校正因子 \(e^{\tilde{m}_{ij}-m^{new}_i}\),平移校正为以新的全局最大值 \(m^{new}_i\)​ 为基准的贡献总和;Part3 作为新旧总和相加得到的新的全局总和的分母,也就是全局归一化常数 \(\ell^{new}_i\)
    经过以上操作,新的 \(O_i\) 就被计算出来了。
通俗类比:计算班级平均分

为了理解这个操作,我们可以想象你在计算一个不断有新学生进入的班级的平均分

  • \(O_i​\) (旧输出):相当于 “A组学生的平均分”
  • \(\ell_i​\) (旧分母):相当于 “A组学生的人数”
  • 新数据:相当于 “B组学生的总分”
  • 目标:计算 A组 + B组 的总平均分。
    现在不能直接把“A组平均分”和“B组平均分”加起来除以2。必须这样做:
  1. 反归一化(还原总分):
    A组总分=A组平均分×A组人数
    (对应公式中的 \(diag(\ell_i​)O_i​\))
  2. 合并:
    全班总分=A组总分+B组总分
    (对应公式中的括号内相加)
  3. 重归一化(算新平均):
    新平均分=全班总人数全班总分​
    (对应公式中的 \({diag(\ell^{new_i}​)^{−1}}\))
    FlashAttention 的公式就是在做这个“加权平均更新”,只不过它多了一个步骤:因为它处理的是指数函数 (\(e^x\)),所以还需要处理 \({m_i}\)​ 变化带来的“指数缩放修正”。
http://www.jsqmd.com/news/66754/

相关文章:

  • 2025年絮凝剂定制厂家权威推荐:碱式聚合氯化铝/生产聚合氯化铝/聚合氯化铝铁源头企业精选
  • Electron.js 详解、应用场景及完整案例
  • 2025年12月储罐厂家最新推荐,认准环宇化工,PP储罐、PPH储罐、PP搅拌罐、PHP搅拌罐、聚丙烯储罐、聚丙烯搅拌罐、真空储罐、PP真空储罐、PPH真空储罐、全系列防腐储罐安全靠谱
  • 详细介绍:React Native CLI的搭建
  • 2025年12月制造业咨询公司推荐:行业评测榜与选型对比指南
  • 2025年五大台球桌品牌技术实力排行榜:英森台球桌技术实力如
  • 2025年专业市场洞察公司排名,市场洞察哪家公司好
  • 2025年12月制造业咨询公司推荐:专业排行榜单全面解析
  • 2025年全国防盗门售后电话推荐,广东星月神防盗门售后全解析
  • 2025年五大专业的互联网医院代办咨询公司推荐,互联网医院申
  • 2025年五大液压打包机源头厂家排行榜,专业液压打包机生产厂
  • 2025年全国五大防盗门品牌公司推荐:北京春天防盗门售后服务
  • 2025年中国真空上料机生产厂家排名:推荐一下真空上料机制造
  • 2025数字文化展馆服务商TOP5权威测评:国创展览的规模怎
  • 2025年12月制造业咨询公司推荐榜:权威评测与综合实力对比分析
  • 2025年12月制造业咨询公司推荐:服务实力排行榜与实战效果评估
  • 2025年特装搭建厂家综合实力推荐:展厅设计施工/特装展台搭建/展台设计搭建源头服务商精选
  • 实用指南:Epoll 服务器实战教学:从 Poll 到高性能事件驱动模型
  • 2025年10月制造业企业管理咨询公司推荐:专业评测排行榜详细发布
  • 2025年10月企业管理培训公司推荐:行业领先机构排行分析
  • 2025年10月制造业企业管理咨询公司推荐:权威机构综合评测与排名
  • 具身智能:零基础入门睿尔曼机械臂(二)——从API例程到Python实操全解析
  • 【首发】在 RTX 5070 (Blackwell) + WSL2 上跑通最新模型 Meta SAM 3D 的终极方案 - 天马行空
  • Ubuntu命令行模式与图形化模式切换
  • 2025年10月制造业企业管理咨询公司推荐:综合实力排名与对比评价
  • 2025年安全滑触线企业权威推荐:多级滑触线/无接头滑触线/三级滑触线源头厂家精选
  • 国标GB28181算法算力平台EasyGBS在物业视频安防管理服务中的应用方案
  • 2025年10月企业管理咨询公司推荐:领先品牌排行及选择策略指南
  • OpenCV立体视觉深度匹配:SGBM算法详解
  • 2025成人口语学习机构TOP5权威推荐:破解成人英语语调口