Scalify:基于e-graph与符号推理的分布式机器学习静默错误检测工具
1. 项目概述与核心挑战
在分布式机器学习的世界里,我们常常需要将一个庞大的模型拆解,分散到成百上千个计算设备(GPU、TPU、Neuron Core)上并行执行,以应对模型参数量和数据量的爆炸式增长。这个过程,我们称之为模型并行化。听起来很美好,但实际操作起来,就像指挥一个庞大的交响乐团,每个乐手(计算设备)都必须严格遵循乐谱(计算图)的指示,稍有差池,整场演出就会走调。这里的“走调”,在机器学习中就是静默错误——程序不崩溃、不报错,但悄无声息地输出了错误的结果。
想象一下,你训练了一个月的大模型,最后发现因为一个张量在某个设备上被错误地reshape或transpose了一下,导致所有努力付诸东流。这种错误极难通过常规测试发现,因为它们不引发异常,只导致模型精度下降或行为异常。Scalify正是为了解决这个痛点而生的工具。它的核心使命,是自动验证一个模型在单设备上运行的计算图(我们称之为基线图),与它在分布式环境下经过各种并行化优化(如张量并行、专家并行)后产生的计算图,在数学语义上是否完全等价。
简单来说,Scalify要回答的问题是:“我做的这些为了加速而进行的复杂变换,有没有改变模型原本要计算的东西?” 这个问题的答案,对于确保百亿、千亿参数大模型训练和推理的可靠性至关重要。接下来,我将深入拆解Scalify是如何运用e-graph(等式图)这一形式化工具,结合巧妙的符号推理,来高效、精准地完成这项看似不可能的任务的。
2. 核心原理:当e-graph遇见张量布局
要理解Scalify,首先得弄明白它手中的两把“利器”:e-graph和张量布局关系。
2.1 e-graph:将“等价”形式化
e-graph(等式图)是一种数据结构,它不直接存储计算表达式,而是存储表达式之间的等价关系。你可以把它想象成一个“等价类”的集合。在e-graph中,一个计算图里的每个操作(节点)都可以有多个不同的、但语义等价的表达形式,这些形式被归入同一个等价类(e-class)中。
例如,对于矩阵乘法(A @ B) @ C,根据结合律,它等价于A @ (B @ C)。在e-graph中,这两个表达式会被归入同一个e-class。Scalify利用这一点,将基线计算图和分布式计算图都“喂”给同一个e-graph。工具内部预置了一系列元规则,这些规则描述了在分布式环境下,哪些变换是保持语义不变的。例如,“一个All-Reduce操作后接一个加法,可以等价于先在本地做加法,再进行All-Redduce”(如果数据布局允许)。
Scalify的工作流程,就是不断地应用这些元规则,对e-graph进行重写和合并。如果最终,基线图中代表模型输出的节点,和分布式图中对应的输出节点,被合并到了同一个e-class中,那么我们就证明了这两个图在给定规则下是语义等价的。反之,如果它们始终无法合并,就说明存在不等价,即引入了静默错误。
注意:e-graph的重写和合并是“饱和”过程,它会探索所有已知规则下的等价形式。这比单纯地按路径比较两个图要强大得多,因为它能发现通过不同变换序列达到相同结果的等价性。
2.2 张量布局关系:跨越设备边界的桥梁
然而,在分布式场景下,事情没那么简单。一个张量可能被分片到多个设备上。例如,一个形状为[4096, 4096]的权重矩阵,在4路张量并行下,可能被按列切分成4个[4096, 1024]的碎片,分布到4个设备上。这时,设备上的局部张量(Sharded Tensor)和全局的逻辑张量(Global Tensor)之间,就存在一种布局映射关系。
Scalify的核心创新之一,就是显式地建模并跟踪这种布局关系。它为e-graph中的每个张量节点都附加了布局信息。这个信息不是一个具体的形状,而是一个符号化的轴映射表达式。
举个例子,假设基线张量B的形状是(i, j, k),表示三个维度的规模。在分布式图中,对应的张量D可能因为一个reshape操作,变成了(⊗(i, j), k),这里⊗表示将前两个维度合并。那么,B和D之间的布局关系S(B, D)就可以表述为:B的(i, j, k)轴分别映射到D的(⊗(i, j), k)轴。
Scalify在e-graph中传播等价关系时,会严格检查布局关系是否兼容。只有当两个操作(如加法)的输入张量不仅计算语义等价,而且它们的布局关系也一致时,Scalify才会认为这两个操作的结果是等价的,并合并它们的输出节点。这就像确保交响乐团中,不仅小提琴组和中提琴组演奏的旋律要对,他们的音高基准(布局)也必须一致,否则合奏就是混乱的。
3. 关键技术拆解:双射推断与错误定位
理解了基本原理,我们来看Scalify解决的两个最棘手的问题:如何判断复杂的布局变换序列是否等价,以及一旦发现不等价,如何精准定位bug。
3.1 符号化双射推断:破解布局变换的迷宫
分布式优化常常引入一连串的reshape和transpose操作来调整数据布局,以适应不同的并行策略或硬件特性。如图9所示,基线路径和分布式路径可能各自经过不同的变换序列,最终得到形状相同的张量。Scalify需要判断:这两个序列在数学上是否描述的是同一个张量重排?
它的方法是符号化双射推断,其算法(对应原文Algorithm 2)可以分解为四步:
生成符号表达式:将每条路径上的布局操作序列(如
reshape,transpose)转化为对张量轴的符号化操作。例如,形状(4, 64, 4096)被符号化为(i, j, k)。一个reshape((4,64,4096), (256,4096))操作,会被转化为将轴(i, j)合并为⊗(i, j),得到表达式(⊗(i, j), k)。秩归一化:比较两条路径的最终符号表达式。如果它们的“秩”(即维度数量)不同,Scalify会尝试通过引入虚拟的、规模为1的维度来进行归一化,使两者具有相同的秩,以便于比较。如果无法合理归一化,则直接判定为不等价。
寻找置换双射:这是核心步骤。Scalify试图在归一化后的分布式表达式
Êd和基线表达式Êb之间,找到一个轴的置换关系。它逐个检查Êb中的每个符号轴(如i),在Êd中寻找结构上相等的对应轴(考虑布局映射关系M)。如果能找到唯一且一一对应的映射,就形成了一个置换索引(如(1, 0, 2)),这对应一个transpose操作。构造操作序列:根据找到的置换关系,Scalify可以反向构造出一个操作序列。这个序列作用在分布式路径的终点张量上,能将其变换到与基线路径终点张量完全一致的布局。例如,生成的序列可能是
[reshape((256,4096), (64,4,4096)), transpose(1,0,2), reshape((64,4,4096), (256,4096))]。如果这个构造的序列能使得两条路径等价,则双射推断成功。
这个过程本质上是在符号层面进行图同构匹配。它避免了直接对具体数值进行昂贵的符号执行,而是利用张量操作的代数性质进行推理,效率极高。
实操心得:双射推断的成功高度依赖于对
reshape操作语义的精确建模——即它只能合并或拆分连续的维度。Scalify目前将范围限定于此,这覆盖了生产框架中绝大多数情况(如Megatron-LM、DeepSpeed中的张量分组),在实用性和复杂性之间取得了良好平衡。如果你想将其用于更诡异的维度重排,可能需要扩展这里的规则。
3.2 基于e-graph差异的代码级错误定位
验证失败(输出“不相等”)只是一个开始。对开发者来说,更重要的是知道bug在哪里。Scalify的另一个强大之处在于它能进行精确的错误定位。
它通过在编译过程中插桩,将中间表示(IR)图中的每个节点,都与源代码的抽象语法树(AST)节点关联起来,记录下文件名、函数名、行号等元数据。当e-graph重写完成并检测到不等价时,Scalify不会简单地列出所有“未验证”的节点——在复杂的非等价图中,这样的节点可能成千上万,毫无帮助。
Scalify采用了一种更聪明的策略:它遍历那些“未验证”的节点,但只关注那些输入节点已经被验证为等价的未验证节点。为什么?因为如果一个节点的所有输入都是等价的,但这个节点本身的输出不等价,那么问题很可能就出在这个节点所代表的操作上,或者其直接的变换上。
如图10所示的例子,一个add操作未被验证。Scalify检查发现,它的两个输入张量虽然各自都能通过某种双射与基线图的对应输入对齐,但对齐所需的双射序列却不相同。这意味着,在add操作执行之前,两个输入张量的数据布局已经不一致了,因此add无法进行有效的元素级加法。Scalify会报告这个add节点及其源代码位置,并指出其输入已验证但自身失败,从而将开发者的注意力直接引向导致布局分歧的根源——很可能是不正确的reshape或通信操作。
这种定位方式极大地缩小了调试范围,将海量的IR节点排查,变成了对少数几个“输入已验证但自身未验证”的关键节点的审查。
4. 实现与评估:在真实模型上的实战表现
理论再优美,也需要实践检验。Scalify的实现大约有9000行Python代码,其中约6500行用于手动编码25个针对不同并行模式(张量、专家、序列并行等)的元规则。它构建在PyTorch XLA之上,直接处理ML模型的中间表示,并与egglog集成作为其e-graph引擎。
4.1 验证能力与效率
评估显示,Scalify能够处理Llama-3.1(8B到405B参数)和Mixtral(8x7B, 8x22B)这样的真实世界大模型。所有验证均在几分钟内完成(见表2),运行在普通的6核CPU和16GB内存的机器上。这与之前一些基于SMT求解器的方法(如TrainVerify)形成了鲜明对比,后者对于405B的模型可能需要数天时间。
关键效率来源:
- 张量级抽象:Scalify在张量层面(而非元素层面)进行推理,将分片张量视为一个整体实体,极大减少了推理状态。
- 层记忆化与图分区:对于Transformer这类具有重复层结构的模型,Scalify采用了层记忆化技术。它验证完一个Transformer层后,会缓存该层的等价性证明,后续相同结构的层直接复用结果,避免了重复计算。同时,它将大计算图分区成子图分别处理,降低了单次e-graph重写的复杂度。
- 复杂度与模型规模解耦:如图11所示,Scalify的验证时间与输入张量的具体形状(序列长度、批大小)以及并行度(TP degree)无关,仅与模型的层数呈线性关系(图11c)。这是因为其推理完全在计算图的结构层面进行,不涉及具体数值。
4.2 错误检测效果
Scalify被设计用于检测五类典型的静默错误:
- 错误的分布式操作:例如,使用了不必要的
all-reduce,或者该用all-gather时用了reduce-scatter。 - 错误的分布式配置:例如,归约操作只在部分设备子集上进行。
- 不一致的张量精度:单机流水线和分布式流水线使用了不同的数值精度(如FP16 vs BF16)。
- 错误的轴切分:
reshape操作错误地分割了张量,破坏了分片关系。 - 错误的布局优化:与基线相比,使用了无效的布局变换序列。
在复现的19个历史真实bug中,Scalify成功检测出17个,并在1分钟内完成。更重要的是,它能将其中许多错误定位到具体的源代码行或可疑函数。此外,在评估过程中,Scalify还在AWS Neuron SDK中发现了5个此前未知的新bug,这些bug都可能导致严重的正确性问题,并已提交给开发者修复。
5. 局限性与未来方向
当然,Scalify并非万能。首先,它专注于计算图级别的验证。那些在图形编译阶段之后出现的错误,例如分布式流水线中的数据竞争、运行时内存错误等,超出了它的检测范围。其次,Scalify是可靠但不完备的。这意味着,所有被它验证通过的图,我们都可以相信是正确的(可靠性);但有些正确的图,可能因为其使用的变换超出了当前预定义元规则或双射推断的范围,而无法被验证(不完备性)。
目前,Scalify对Tensor Parallelism, Flash Decoding, Expert Parallelism等主流并行模式支持良好,但对于更复杂的流水线并行或涉及动态控制流的模式,需要额外的工程努力来扩展规则集。最后,虽然Scalify能精确定位到出现差异的代码行,但根因分析有时仍需开发者手动进行。未来的一个有趣方向是结合大语言模型(LLM),利用其代码理解能力,对Scalify定位出的可疑代码片段进行自动分析,推测可能的错误原因,从而进一步降低调试门槛。
从我个人的实践经验来看,像Scalify这样的形式化验证工具,正在成为大规模机器学习系统开发中不可或缺的“安全带”。它不能替代全面的测试,但能为那些最隐蔽、代价最高的静默错误提供一道强有力的防线。在动辄消耗数百万美元计算资源的千亿模型训练中,提前几分钟发现一个布局错误,其价值不言而喻。将验证左移,从运行时测试前置到编译时证明,是提升ML系统可靠性的必然趋势。
