当前位置: 首页 > news >正文

Strassen 矩阵分治乘法

矩阵乘法看起来太基础了,以至于很容易被当成一个已经“没有算法空间”的操作。两个 \(n \times n\) 矩阵相乘,按定义写三层循环,时间复杂度是 \(O(n^3)\)

\[C_{ij} = \sum_{k=1}^{n} A_{ik}B_{kj} \]

这个式子直接翻译成代码很自然:每个 \(C_{ij}\) 做一个长度为 \(n\) 的点积,一共有 \(n^2\) 个元素。朴素算法的优点也很明显:实现简单、内存访问模式容易优化、数值行为稳定,现代 BLAS 库里的 GEMM 本质上仍围绕这个计算结构做极致常数优化。

Strassen 算法提出的问题不是“能不能把三层循环写得更快”,而是更根本一点:

如果把矩阵乘法看成一个代数计算问题,\(n^3\) 次乘法真的是必须的吗?

1969 年,Volker Strassen 在 Gaussian elimination is not optimal 中给出了第一个渐进快于 \(O(n^3)\) 的通用矩阵乘法算法,也就是今天说的 Strassen 矩阵乘法。它的复杂度是 \(O(n^{\log_2 7})\),约为 \(O(n^{2.807})\)。(Springer)

从块矩阵开始看普通乘法

尝试分治,把两个矩阵各切成四块:

\[A = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}, \quad B = \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix} \]

按普通矩阵乘法,结果 \(C = AB\) 的四个块是:

\[\begin{aligned} C_{11} &= A_{11}B_{11} + A_{12}B_{21} \\ C_{12} &= A_{11}B_{12} + A_{12}B_{22} \\ C_{21} &= A_{21}B_{11} + A_{22}B_{21} \\ C_{22} &= A_{21}B_{12} + A_{22}B_{22} \end{aligned} \]

如果每个 \(A_{ij}\)\(B_{ij}\) 都是大小为 \(n/2\) 的子矩阵,那么这里需要 8 次规模为 \(n/2\) 的矩阵乘法,再加上一些矩阵加法。递归地做下去,并不会改变复杂度:\(T(n)=8T(n/2)+O(n^2)\),解出来仍然是 \(O(n^3)\)

Strassen 的突破点是:这 8 次子矩阵乘法可以被改写成 7 次。代价是更多的加减法。

先在 2 × 2 上把戏法拆开

为了避免块矩阵符号太重,先写两个普通 \(2 \times 2\) 矩阵:

\[A = \begin{bmatrix} a & b \\ c & d \end{bmatrix}, \quad B = \begin{bmatrix} e & f \\ g & h \end{bmatrix} \]

普通算法会计算:

\[\begin{aligned} C_{11} &= ae + bg \\ C_{12} &= af + bh \\ C_{21} &= ce + dg \\ C_{22} &= cf + dh \end{aligned} \]

这里有 8 次标量乘法:\(ae,bg,af,bh,ce,dg,cf,dh\)。Strassen 不直接算这些项,而是构造 7 个中间量:

\[\begin{aligned} M_1 &= (a + d)(e + h) \\ M_2 &= (c + d)e \\ M_3 &= a(f - h) \\ M_4 &= d(g - e) \\ M_5 &= (a + b)h \\ M_6 &= (c - a)(e + f) \\ M_7 &= (b - d)(g + h) \end{aligned} \]

然后用它们组合出结果:

\[\begin{aligned} C_{11} &= M_1 + M_4 - M_5 + M_7 \\ C_{12} &= M_3 + M_5 \\ C_{21} &= M_2 + M_4 \\ C_{22} &= M_1 - M_2 + M_3 + M_6 \end{aligned} \]

这个公式第一次看会有点像魔术。检查其中一个元素就能看出它不是近似,而是精确恒等式。以 \(C_{11}\) 为例:

\[\begin{aligned} M_1 + M_4 - M_5 + M_7 &= (a+d)(e+h) + d(g-e) - (a+b)h + (b-d)(g+h) \\ &= ae + ah + de + dh + dg - de - ah - bh + bg + bh - dg - dh \\ &= ae + bg \end{aligned} \]

中间多出来的项被安排成互相抵消。其他三个元素也一样。这里减少的是乘法次数,增加的是加减法次数。对标量来说,这不一定划算;但当 \(a,b,c,d,e,f,g,h\) 本身是大矩阵时,一次“乘法”代表一个规模为 \(n/2\) 的矩阵乘法,加减法只需要线性扫一遍子矩阵。递归放大后,收益就出现了。

把公式提升到矩阵块

现在把 \(a,b,c,d,e,f,g,h\) 换回矩阵块:

\[\begin{aligned} M_1 &= (A_{11} + A_{22})(B_{11} + B_{22}) \\ M_2 &= (A_{21} + A_{22})B_{11} \\ M_3 &= A_{11}(B_{12} - B_{22}) \\ M_4 &= A_{22}(B_{21} - B_{11}) \\ M_5 &= (A_{11} + A_{12})B_{22} \\ M_6 &= (A_{21} - A_{11})(B_{11} + B_{12}) \\ M_7 &= (A_{12} - A_{22})(B_{21} + B_{22}) \end{aligned} \]

结果块仍然是:

\[\begin{aligned} C_{11} &= M_1 + M_4 - M_5 + M_7 \\ C_{12} &= M_3 + M_5 \\ C_{21} &= M_2 + M_4 \\ C_{22} &= M_1 - M_2 + M_3 + M_6 \end{aligned} \]

这个提升成立的前提是矩阵块的加法、减法和乘法满足通常的环运算规则。它不依赖交换律,因为矩阵乘法本来就不交换;公式中乘法左右顺序没有被调换过。但它确实依赖减法,所以不能直接搬到没有减法的半环上,例如某些图算法里的 \(\min,+\) 乘法。

递归版本的 Strassen 可以写成下面这种结构:

strassen(A, B):if size(A) <= cutoff:return classical_gemm(A, B)split A into A11, A12, A21, A22split B into B11, B12, B21, B22M1 = strassen(A11 + A22, B11 + B22)M2 = strassen(A21 + A22, B11)M3 = strassen(A11,       B12 - B22)M4 = strassen(A22,       B21 - B11)M5 = strassen(A11 + A12, B22)M6 = strassen(A21 - A11, B11 + B12)M7 = strassen(A12 - A22, B21 + B22)C11 = M1 + M4 - M5 + M7C12 = M3 + M5C21 = M2 + M4C22 = M1 - M2 + M3 + M6join C11, C12, C21, C22

这里的 cutoff 不是实现细节里的小补丁,而是实际性能的关键。递归到底层标量没有意义,因为标量乘法虽然减少了 1 次,却引入了大量加减法、临时对象和函数调用。一般当递归到某个规模一下后,切回高度优化的普通矩阵乘法。

复杂度为什么变成 \(O(n^{2.807})\)

假设矩阵大小是 \(n \times n\),且 \(n\) 是 2 的幂。每一层递归做 7 次规模为 \(n/2\) 的子问题,并做 \(O(n^2)\) 的矩阵加减法,所以有递推式:

\[T(n) = 7T(n/2) + O(n^2) \]

根据 Master Theorem,主项来自递归子问题,因为 \(\log_2 7 \approx 2.807 > 2\)

\[T(n) = O(n^{\log_2 7}) \approx O(n^{2.807}) \]

这就是 Strassen 和普通分治的差别。普通分治只是把 \(n\) 变成 \(n/2\),但子问题数量仍是 8;Strassen 真正改变了递归树的分支因子。

不过复杂度公式也容易误导。\(O(n^{2.807})\) 只说明当 \(n\) 足够大时增长率更慢,不代表任何尺寸都总是更快。即使不考虑内存访问,、SIMD 等因素。Strassen 在减少大矩阵乘法的同时也引入了大量加减法,总运算量不一定更低。

数学公式写起来很短,工程实现会遇到几个更具体的问题。

尺寸不一定是 2 的幂

教科书常把 \(n\) 假设为 2 的幂。实际矩阵很少这么配合。最直接的办法是补零到下一个 2 的幂,但这可能浪费大量内存和计算。例如 \(1025 \times 1025\) 补到 \(2048 \times 2048\),代价非常难看。

更常见的处理是:只对能均匀切分的主体部分使用 Strassen,边缘部分回退到普通 GEMM;或者递归时允许不完全均分,把某些奇数尺寸分给其中一个子块。后者实现更复杂,因为切片、临时矩阵和结果拼接都要处理不规则尺寸。

临时矩阵会吃掉收益

按最直观的实现,每个 \(M_i\) 都需要构造输入临时矩阵,例如 \(A_{11}+A_{22}\)\(B_{11}+B_{22}\)。再加上输出中间量 \(M_1,\dots,M_7\),内存分配压力会很高。

这也是为什么实际实现通常会做调度优化:复用 buffer,边算边释放,避免把所有临时结果同时留在内存里。有些变体会使用 Winograd 形式,保持 7 次子乘法不变,但减少块加减法数量。它不改变渐进复杂度,改变的是常数和内存流量。

对程序员来说,一个有用的判断是:Strassen 省下的是一次大矩阵乘法,但多出来的是很多次全矩阵读写。如果矩阵规模还没有大到让乘法完全压倒内存流量,它就可能不划算。

数值误差

Strassen 不会“算错”,但在实数浮点计算里,它改变了加减法顺序,并引入了更多中间表达式。像 \(B_{12}-B_{22}\) 这样的操作可能发生消去,导致相对误差放大。普通 GEMM 也有舍入误差,但它的误差行为更直接,硬件和库也针对这种模式优化了很多年。

所以在数值线性代数里,Strassen 的使用会更谨慎。对整数、多项式、有限域这类精确代数场景,数值稳定性不是问题,Strassen 或类似快速乘法更容易发挥作用。对机器学习训练、仿真、图形渲染这类浮点密集场景,吞吐量、内存带宽、并行效率和可预测性往往比渐进指数更重要。

还有更快的算法吗

研究矩阵乘法复杂度时,常用 \(\omega\) 表示矩阵乘法指数:如果两个 \(n \times n\) 矩阵可以在 \(O(n^{\omega+\epsilon})\)\(n^{\omega+o(1)}\) 次域运算内完成,就说对应指数是 \(\omega\)。普通算法对应 \(\omega=3\),Strassen 对应 \(\omega \le \log_2 7 \approx 2.807\)

后来出现了 Coppersmith-Winograd 系列方法、laser method 及其改进。到 2024 年,Alman、Duan、Vassilevska Williams、Y. Xu、Z. Xu 和 Zhou 给出的上界已经推进到 \(\omega < 2.371339\),比此前的 \(\omega < 2.371552\) 更低。(arXiv)

这类结果很重要,因为大量理论算法会把矩阵乘法当成子程序,\(\omega\) 的下降会传导到图算法、代数算法、动态规划优化等问题。但它们通常不是你会在生产代码里调用的“更快 GEMM”。隐藏常数、构造复杂度和所需矩阵规模都太大。

这些算法常被称为 galactic algorithms:理论上最终会赢,但赢的规模远超现实机器能处理的范围。对于 Strassen 算法(\(\omega \approx 2.807\)),通常在 \(n\) 达到几千(例如 \(n \approx 4000\))时就能实际超越普通 \(O(n^3)\) 乘法,因此在部分科学计算中仍有应用;但对于 Coppersmith-Winograd 类算法及其后续改进(如 \(\omega < 2.371339\)),隐藏常数极大,通常需要 \(n\) 超过 \(10^{12}\) 乃至 \(10^{40}\) 才能体现出渐进优势——这意味着即便用宇宙中所有粒子作为计算单元,也无法在合理时间内达到这一规模,因此它们纯粹是理论上的“星系级算法”,永远不会出现在生产代码中。

近几年还有一个有趣方向是自动搜索小尺寸矩阵乘法算法。DeepMind 的 AlphaTensor 把寻找矩阵乘法张量分解建模成单人游戏,用强化学习搜索可证明正确的乘法方案,并在若干小尺寸或特定代数结构上找到新的算法。这个方向和 \(\omega\) 前沿不完全是一回事:它更像是在“算法设计空间”里找可用构件,有些结果关注乘法次数,有些也关注特定硬件上的运行时间。(Nature)

现实里矩阵乘法通常怎么做

如果目标是把一个浮点 dense GEMM 跑快,工程路径通常不是先上 Strassen,而是把普通三层循环改造成适合硬件的形状。现代高性能 GEMM 会围绕 cache blocking、packing、寄存器分块、SIMD/FMA、线程划分和 NUMA 行为做优化。

BLIS/GotoBLAS 风格的实现很能说明这个思路:大矩阵被分成适合 cache 的 panel 和 block,核心是一个很小的 micro-kernel,负责更新 \(m_R \times n_R\) 的 C 微块;A 和 B 的数据会被 pack 到连续内存中,让 micro-kernel 以更友好的步长读取。BLIS 文档和论文中也明确把 GEMM 组织成围绕 micro-kernel 的多层循环结构,并讨论了 packing 对大矩阵局部性的作用以及小矩阵上 packing 可能带来的负担。

这和 Strassen 优化的是不同层面。Strassen 减少“数学乘法”的数量;BLAS/GEMM 优化让每次乘加尽量贴近硬件峰值。对于常见神经网络训练和推理,GPU/TPU 上的矩阵乘法还会进一步利用 tensor cores、低精度格式、tiling 和融合算子。实际瓶颈经常是数据搬运、布局转换、kernel launch、batching,而不是 \(O(n^3)\) 这个符号本身。

这并不意味着 Strassen 没有实践价值。它在大规模 dense 矩阵、精确代数、某些计算机代数系统中可以有意义;在数值计算库里,也可能作为可选路径出现在较大尺寸矩阵上。但它通常会被限制递归层数,并在底层回退到高度优化的 GEMM。真正的工程问题不是“Strassen 是否渐进更快”,而是“在当前矩阵规模、数据类型和硬件上,省掉的乘法能不能覆盖额外内存流量和误差风险”。

如果要自己实现一个较高效率的 Strassen,合理的版本大概是这样:

matmul(A, B):if not square_enough(A, B):return gemm(A, B)if min_dimension(A, B) < cutoff:return gemm(A, B)if estimated_temporary_memory_too_large(A, B):return gemm(A, B)return one_or_two_levels_of_strassen(A, B, base=gemm)

这里的 square_enough 是因为 Strassen 最自然适合方阵或接近方阵的 dense 乘法。非常瘦长的矩阵乘法,例如 \(m \times k\)\(k \times n\) 且某个维度很小,直接 GEMM 往往更好。estimated_temporary_memory_too_large 则是为了防止递归临时对象把 cache 和内存带宽打爆。

如果要继续优化,可以把递归层数作为参数,而不是只用 cutoff:

strassen(A, B, depth):if depth == 0:return gemm(A, B)if bad_shape_or_small_size(A, B):return gemm(A, B)split, compute seven products with depth - 1combine

这种写法更容易 benchmark。比如只试 1 层 Strassen、2 层 Strassen,再比较底层全用 GEMM 的耗时。很多时候 1 层已经能体现收益,继续递归反而让内存系统成为瓶颈。

总结

Strassen 最有价值的地方是它展示了一种分治算法的非平凡形态。普通分治只是切问题;Strassen 先改变问题的代数表示,再切问题。它用线性组合制造中间量,用抵消关系恢复正确结果,从而把递归分支数从 8 降到 7。

而在实际代码里,朴素 \(O(n^3)\) 乘法实际上反而更常用,因为他在中小规模下最容易经过硬件友好的实现后增强;在前沿理论研究里,矩阵乘法指数仍在缓慢下降;而 Strassen 作为较早提出的优化,既是漂亮的复杂度突破,也确实能在某些现实条件下使用。

http://www.jsqmd.com/news/879815/

相关文章:

  • 2026年宁波口碑好、专业、质量过硬且售后服务优质的手机维修店铺综合实力排行榜 - 资讯纵览
  • 2026年东莞冻品批发渠道分析:线上平台如何重塑传统采购模式 - 资讯纵览
  • 量子计算机的核心技术难点
  • 栈以及队列的详细讲解
  • 2026年5月优秀的气动蝶阀/气动截止阀厂家推荐钢特阀门科技有限公司 - 品牌鉴赏师
  • 2026年5月江门蓬江地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 诚信金利回收
  • HashMap 源码解析 底层原理 面试如何回答
  • 企业如何利用Taotoken实现多模型API的统一管理与访问控制
  • 驾照证件照怎么制作?2026驾驶证照片规范+手机制作教程 - 科技大爆炸
  • 多版本滤波算法对比试验
  • 2026 年成都钢板厂家及采购优选推荐 四川盛世钢联钢厂联营资源等你来抢 - 四川盛世钢联营销中心
  • 喜马拉雅xm-sign v3算法逆向解析与Node.js本地生成
  • 如何快速将视频格式转换为MP4?MKV、FLV、MOV转MP4就这么简单!
  • 医疗AI模型窃取攻击:原理、风险与超声影像场景的防御实践
  • 用 AutoGen 编排多智能体协作,让 AI 团队帮你干活
  • 2026年5月江门台山地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 诚信金利回收
  • 将taotoken接入openclaw agent工作流的配置要点
  • 2026年5月济宁梁山地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 诚信金利回收
  • Java方法全解析:从基础定义到重载机制
  • 漏洞研究工作流:从CVE追踪到实战提升的闭环方法论
  • 如何发起投票活动,投票小程序操作指南 - 资讯纵览
  • 新手教程使用curl命令快速测试Taotoken的OpenAI兼容接口
  • Grafana 从零上手:安装部署、仪表盘导入导出及插件安装完整指南
  • 如何发布一场投票评选活动,投票小程序操作指南 - 资讯纵览
  • 2026 出海 GEO 避坑指南:源码技术成试金石,旗引科技领跑国产第一梯队 - 资讯纵览
  • B4A要编绎成Release发布APP/waiting for ide debugger to connect
  • 2026年5月济宁曲阜地区黄金回收白银铂金回收门店推荐TOP1 地址及联系方式 - 诚信金利回收
  • 2026年中国出海GEO行业深度观察:源码私有化部署成为技术分水岭 - 资讯纵览
  • 基于决策树与Boosting的暗网流量多阶段分类系统设计与实践
  • 终极AMD Ryzen调试工具:免费开源的硬件掌控神器