Scalify:基于e-graph的分布式机器学习计算图等价性验证工具
1. 项目概述
在分布式机器学习的世界里,我们常常面临一个看似简单实则棘手的问题:我写的这个并行化代码,真的和单机版本在数学上等价吗?这个问题背后,是无数个深夜调试的工程师,是那些在数百个GPU上跑了一周才发现结果不对的昂贵实验,更是那些难以察觉却足以毁掉整个模型训练的“静默错误”。这些错误不会导致程序崩溃,它们悄无声息地改变着张量的布局、精度或通信逻辑,最终让训练出的模型性能远低于预期,而你却可能把原因归结为超参数没调好。
Scalify正是为了解决这个痛点而生。它不是一个运行时监控工具,而是一个编译时验证器。它的核心任务,是在你的分布式计算图(比如为32个GPU设计的张量并行计算流)被真正执行之前,就证明它在数学语义上与一个已知正确的“基线”计算图(通常是单机版本)完全等价。这听起来像是形式化验证的领域,但Scalify的巧妙之处在于,它没有选择传统的、计算量巨大的SMT求解器路径,而是引入了来自编程语言优化领域的“e-graph”(等式图)技术,结合了关系传播和符号推理,将验证时间从“天”级别压缩到了“分钟”级别。
想象一下,你正在为一个拥有4050亿参数的Llama-3.1模型实现Flash Attention的分布式版本。你写了一系列复杂的reshape和transpose操作来调整张量布局以适应多卡通信,同时加入了all-reduce来同步梯度。代码能跑通,loss也在下降,但你怎么能百分之百确定,经过这一系列眼花缭乱的变换后,最终每个设备上的计算结果,和你在单张卡上顺序执行的结果,在数学上是一模一样的?Scalify就是你的“数学证明助手”。它不关心具体的数值,而是关心计算的“形状”和“关系”。它通过分析计算图的结构、张量的维度映射以及操作符的语义,来推理两个图是否表达了同一个数学函数。
2. 核心原理:为什么是e-graph与关系传播?
要理解Scalify,首先要抛开“逐元素比对”的直觉。在分布式场景下,一个大的张量被切分(Shard)到多个设备上,每个设备只持有数据的一部分。直接比较两个图上对应节点的具体数值是行不通的,因为数据分布已经不同。Scalify的智慧在于,它比较的是“关系”而非“数值”。
2.1 e-graph:将等价性搜索转化为图重构问题
e-graph(Equality Graph)是一种数据结构,它能够高效地表示一个表达式集合,并记录这些表达式之间的等价关系。在传统编译器优化中,e-graph被用来做等式饱和(Equality Saturation),穷举一个表达式所有可能的等价形式,从而找到最优的实现。
Scalify创新性地将e-graph用于验证。它的工作流程可以概括为以下几步:
- 图转换:将输入的基线计算图(单机)和分布式计算图,分别转换为两个初始的e-graph。图中的每个节点代表一个操作(如
dot,add,reshape),边代表数据流。 - 规则应用:预定义一组“重写规则”(Rewrite Rules)。这些规则描述了在保持语义等价的前提下,一个计算子图可以被另一个子图替换。例如,一个
transpose后接reshape的操作序列,可能等价于另一个不同的reshape后接transpose的序列。Scalify将这些规则反复应用到两个e-graph上。 - 等价类合并:当规则应用表明两个节点(或子图)等价时,e-graph会将它们合并到同一个“等价类”中。这个过程是自反、对称和传递的。
- 最终判定:经过多轮规则应用和合并后,Scalify检查两个e-graph的“根节点”(即整个计算图的输出)是否被合并到了同一个等价类中。如果是,则证明两个计算图语义等价;否则,不等价。
注意:这里的“规则”不是任意的,而是基于严格的数学定义和分布式计算语义(如集合通信
all-reduce的语义)手工编码的。Scalify论文中提到了编码了约25条这样的元规则,覆盖了张量并行、专家并行等多种并行模式。
2.2 关系传播:在分布式上下文中建立桥梁
e-graph解决了“如何判断等价”的问题,但分布式图与基线图之间还存在一个根本差异:数据布局。这就是“关系传播”要解决的问题。
Scalify在验证之初,会为两个图中对应的输入张量(例如,同一个权重矩阵在单机版本和切分后的版本)建立一种“布局关系”。这种关系描述了分布式张量中的每个逻辑元素,对应到基线张量中的哪个位置。例如,对于一个形状为(1024, 4096)的矩阵进行按列切分的张量并行(TP=4),每个设备上的张量形状为(1024, 1024)。布局关系会记录:设备0上的张量A_shard[0]对应基线张量A的列[0:1024]。
验证的核心,就是让这种布局关系随着计算图的执行而“传播”下去。Scalify会沿着数据流图,分析每一个操作符(Operator)对输入输出张量布局的影响:
- 逐元素操作(如
add,mul,sin):不改变布局。如果输入张量A和A‘有布局关系R,那么经过add操作后,输出张量B和B‘也自动具有相同的布局关系R。 - 改变形状的操作(如
reshape,transpose):会改变布局关系。Scalify需要计算新的映射。这是最复杂、最容易出错的部分,也是Scalify算法(双射推断)的核心应用场景。 - 通信操作(如
all-reduce,all-gather):会同步或重组数据。Scalify需要根据通信操作的语义(例如,all-reduce是求和后广播)来更新布局关系。例如,一个按行切分的张量经过all-reduce后,每个设备都拥有了完整的行,其布局关系就变成了“复制”关系。
如果关系能够从输入一直无矛盾地传播到最终输出,并且最终输出的布局关系表明它们在逻辑上是同一个完整张量(或等价的复制品),那么验证就通过了。如果在某个节点,输入的关系无法推导出一致的输出关系,或者两个本应等价的分支出现了关系冲突,Scalify就会标记此处为“未验证”,并定位到可能的错误源头。
3. 核心挑战与突破:张量布局变换的等价性验证
在分布式机器学习优化中,大量的“静默错误”并非源于通信原语用错,而是隐藏在那些为了性能而引入的、复杂的张量布局变换序列中。不同的编译器优化Pass、不同的内核实现,可能会生成不同的reshape和transpose操作序列,但它们的目标是实现同一个逻辑上的数据重排。手动验证这些序列的等价性极其容易出错,而这正是Scalify的“符号双射推断”算法大放异彩的地方。
3.1 问题定义:从具体维度到符号表达式
让我们通过一个论文中的简化例子来理解这个问题。假设基线图中有一个操作序列:张量A (4, 64, 4096) -> reshape -> (256, 4096)我们称之为布局序列Sb。
在分布式图中,为了实现某种并行优化,编译器可能生成了另一个序列:张量A‘ (4, 64, 4096) -> transpose -> (64, 4, 4096) -> reshape -> (256, 4096)我们称之为布局序列Sd。
肉眼观察,(4,64,4096)经过transpose(1,0,2)变成(64,4,4096),再reshape成(256,4096),似乎和直接reshape成(256,4096)是等价的?但我们需要严格的证明。
Scalify的第一步是符号化。它将具体的维度数字(4, 64, 4096)用符号轴代替,例如(i, j, k)。这样,操作就变成了对符号表达式的变换:
Sb:(i, j, k) -> reshape -> (⊗(i, j), k)。这里⊗表示合并(merge)操作。Sd:(i, j, k) -> transpose -> (j, i, k) -> reshape -> (⊗(j, i), k)。
现在,问题转化为:表达式(⊗(i, j), k)和(⊗(j, i), k)是否在逻辑上表示同一个多维数组的索引空间?换句话说,是否存在一个双射(bijection)函数,能将Sd产生的数据布局,一一映射到Sb产生的布局上?
3.2 双射推断算法四步走
Scalify的算法2(Algorithm 2)清晰地描述了这一过程,我们可以拆解为四个步骤:
步骤1:生成符号表达式与轴映射如上所述,为两个布局序列Sb和Sd生成符号表达式Eb和Ed。同时,根据输入张量b和d的切分关系,建立一个初始的轴映射M。例如,如果d是b按第0维切分的一部分,那么M会记录d的轴i‘对应b的轴i的一个子区间。
步骤2:秩归一化Eb和Ed的“秩”(rank,即维度数量)可能因为reshape的合并/拆分而不同。例如,Eb可能是(⊗(i, j), k)(秩为2),而Ed可能是(j, i, k)(秩为3)。Scalify通过引入虚拟的、大小为1的维度,或将合并的维度临时拆开,将两个表达式归一化到相同的秩,得到Êb和Êd。如果无法归一化(例如,总元素数不同),则直接返回“无等价关系”。
步骤3:寻找置换双射这是算法的核心。现在Êb和Êd具有相同的秩。Scalify需要找到一个置换(permutation)p,使得将Êd的轴按照p重新排列后,其结构与Êb在轴映射M的意义下“相等”。 对于Êb中的每一个符号轴(如i),在Êd中寻找一个符号轴(如j),使得在映射M下,j和i代表的是基线张量中同一个逻辑轴(或它的一个部分)。如果对于Êb中的每个轴都能在Êd中找到唯一且不重复的对应轴,那么就成功构造了一个置换p。 在上述例子中,Êb = (i, j, k),Êd = (j, i, k)。显然,i对应Êd中的j,j对应Êd中的i,k对应k。因此得到的置换p = [1, 0, 2],这正好对应一个transpose(1, 0, 2)操作。
步骤4:构造操作序列最后,Scalify将推断出的双射具体化为一个可执行的操作序列。这个序列的作用是,当应用于分布式路径的末尾时,能将其张量布局转换为与基线路径兼容的布局。
- 如果
Êd与原始的Ed形状不同,先添加一个reshape操作,将Ed的形状变为Êd的形状(即归一化后的形状)。 - 添加
transpose(p)操作,应用步骤3找到的置换。 - 如果
Êd经过置换后的形状与Eb的最终形状不同,再添加一个reshape操作,变到最终目标形状。
对于我们的例子,算法会推断出双射操作序列为:[reshape(64, 4, 4096), transpose(1, 0, 2), reshape(256, 4096)]。这个序列恰好是Sd的逆过程!将它应用到分布式路径的末尾,就能“抵消”掉分布式图中额外的transpose,使其输出布局与基线路径一致。
3.3 实操心得:理解算法的局限性
这个算法非常强大,但它也有明确的适用范围,理解这一点对正确使用Scalify至关重要。
注意:Scalify的双射推断算法主要针对维度的合并与拆分(即
reshape操作)以及维度的重排(即transpose操作)。它假设布局变换主要由这两类操作构成,这在生产级ML框架(如PyTorch XLA、TensorFlow)中是非常普遍的。然而,它可能无法处理更广义的、非线性的索引变换(例如gather、scatter、strided_slice等复杂操作)。在实现自己的规则时,需要确保操作语义在算法定义的范畴内。
一个常见的陷阱:算法依赖于准确的轴映射M。如果初始的切分关系定义错误(例如,误以为张量是按行切分,但实际是按列切分),那么整个双射推断就会建立在错误的基础上,导致误判。因此,在配置Scalify验证任务时,明确定义每个输入张量的切分策略是第一步,也是最重要的一步。
4. 实战:使用Scalify定位静默错误
验证工具的价值不仅在于说“是”或“否”,更在于当答案是“否”时,它能告诉你“哪里出了问题”。Scalify的“后处理:基于E-Graph差异定位代码错误”模块正是为此设计。
4.1 错误定位机制
Scalify在将IR计算图转换为e-graph时,会通过编译器的日志API注入元数据,将图中的每个节点与源代码中的具体位置(文件、函数、行号)关联起来。当验证失败时,Scalify不是简单地列出所有“未验证”的节点——在不等价的图中,这样的节点可能非常多。
Scalify采用了一种更智能的溯源策略:
- 分类节点:在重写过程中,将所有节点分为“已验证”和“未验证”两类。
- 寻找分歧点:遍历所有“未验证”的节点,检查它的所有输入节点是否都是“已验证”的。
- 报告根源:如果一个节点的所有输入都已被验证为等价,但这个节点本身却无法被验证,那么这个节点就很可能是错误的根源。因为错误一定发生在这个节点或其内部操作上,而不是更早的祖先节点。
4.2 案例分析:错误的All-to-All布局变换
让我们剖析论文中图10(Figure 10)的经典案例。这是一个在混合并行(如同时使用张量并行和序列并行)中容易出现的错误。
- 基线图(Oracle):路径相对简单,
A和B进行矩阵乘法(dot)后得到C。 - 分布式图(Buggy):为了适应并行,对
A‘和B‘做了额外的reshape和transpose操作,然后进行dot得到C‘。之后,需要对C‘进行一个all-reduce通信操作来聚合结果。
错误出现在哪里?Scalify发现,A‘和B‘的布局变换序列,与C‘和C之间所需的布局关系不匹配。具体来说:
- 为了使
A‘与A对齐,需要应用一个双射[transpose(2,0,1), reshape(...)]。 - 为了使
B‘与B对齐,需要应用同一个双射。 - 但是,为了使
all-reduce之后的C‘与C对齐,需要应用一个不同的双射[transpose(2,1,0), ...]。
这就产生了一个矛盾:C‘是由A‘和B‘计算得来的,如果A‘和B‘遵循第一种变换关系,那么计算得到的C‘理应也通过第一种变换关系与C对齐。但实际代码中,all-reduce操作前后隐含的布局假设却是第二种关系。这种不一致导致Scalify无法将布局关系通过add操作(图中add节点)传播下去。
Scalify的输出:它会报告这个add节点是未验证的,并且其所有输入节点(即all-reduce的结果和另一个张量)都是已验证的。这直接将开发者的注意力引向了这个add操作,或者更准确地说,引向了产生add操作输入的上游代码——即那个有问题的all-reduce布局变换。错误信息会附带源代码行号(例如hlo.py:214),让开发者能迅速定位到问题代码。
4.3 常见错误类型与Scalify的检测能力
根据论文评估,Scalify可以有效检测以下几类分布式机器学习中典型的“静默错误”:
- 不正确的分布式操作:例如,该用
all-gather的时候用了all-reduce,或者多了一个不必要的all-reduce。Scalify通过通信原语的语义规则��检测。 - 不正确的分布式配置:例如,设备分组(replica groups)设置错误,导致只在部分设备上进行规约。Scalify通过分析操作符的设备属性关联来发现。
- 不一致的张量精度:单机图使用FP32,分布式图某个环节错误地使用了FP16或BF16。Scalify可以检查操作符的
dtype属性。 - 不正确的轴拆分:
reshape操作错误地拆分或合并了维度,破坏了张量切分关系。这是双射推断算法的主要检测目标。 - 不正确的布局优化:编译器或手动编写的布局变换序列存在错误,与基线逻辑不等价。同样是双射推断算法的检测目标。
实操心得:并非所有错误都能在计算图层面捕获。Scalify明确指出了其局限性:它专注于计算图IR级别的验证。因此,诸如运行时调度错误(如数据竞争)、在图形编译阶段之前发生的错误、或者那些不影响计算图语义但影响数值稳定性的超细微差别,Scalify可能无法发现。它是一款强大的“图形逻辑验证器”,但不能替代全面的集成测试和数值精度测试。
5. 性能与评估:真的能用于大模型吗?
一个验证工具如果速度太慢,就无法集成到开发流程中。Scalify论文中的评估数据有力地证明了其实用性。
5.1 验证时间:从数天到数分钟
论文对比了之前的SOTA工具TrainVerify。TrainVerify使用SMT求解器进行逐元素的推理,验证一个405B参数的Llama-3.1模型需要数天时间。而Scalify通过对张量进行整体关系推理,将验证时间缩短到了2分37秒(在6核AMD Ryzen 5 5600U CPU,16GB RAM的消费级机器上)。对于8B和70B的模型,验证时间更是低于2分钟。这个时间开销对于在代码提交前或CI/CD流水线中运行检查是完全可接受的。
5.2 可扩展性分析
Scalify的性能表现呈现出几个关键特征,这些特征源于其设计:
- 与张量形状无关:如图11a, 11b, 11e所示,改变序列长度(seqlen)、批大小(batch size)或注意力头数(heads)几乎不影响验证时间。因为Scalify在符号层面操作,计算图的节点和边数量不随具体维度大小变化。
- 与并行度弱相关:如图11d所示,增加张量并行度(TP degree)并不会显著增加验证时间。因为增加核心数主要是在计算图中添加更多的通信节点,而图的整体拓扑结构复杂度增长有限。
- 与模型层数线性相关:如图11c所示,验证时间随模型层数增加而线性增长。这是因为更多的层意味着更长的计算图,需要处理更多的节点。Scalify采用了层记忆化技术来优化这一点:它将模型按层分割成子图,验证完一层后,缓存该层的等价类信息,在验证下一层时可以直接复用,避免了重复计算,大幅提升了效率(见图12对比)。
5.3 真实漏洞检测效果
论文在AWS Transformers NeuronX框架上复现了19个真实世界中的历史bug,并利用Scalify进行检测。结果令人印象深刻:
- 检测率:17/19的bug被成功检测出(约89.5%)。
- 定位精度:对于其中13个bug,Scalify能精确定位到有问题的具体指令(
➸);对于另外4个bug,能定位到有问题的函数或数据结构(✻)。只有2个bug未被检测出,原因是它们发生在图编译阶段之外(如运行时KV缓存切片错误)。 - 发现新bug:在评估过程中,Scalify甚至在Amazon的SDK中发现了5个此前未知的bug,包括不正确的布局优化、错误的all-to-all变换、张量切分错误等,这些都已提交给开发者修复。
这些数据强有力地证明了Scalify不仅是一个学术原型,更是一个能在工业级复杂框架和模型上发现真实问题的实用工具。
6. 实现与应用展望
6.1 实现要点
Scalify的核心实现大约9000行Python代码,其中约6500行用于手工编码那25条关键的元规则。这些规则定义了不同并行模式(张量、专家、序列并行)下各种操作符的语义和等价变换。作者指出,一旦核心框架和基础规则集搭建完成,为新的并行技术添加规则支持所需的工作量是可控的(例如,为序列并行添加2条规则仅需30行代码)。
它构建在PyTorch XLA之上,直接操作ML模型的中间表示(IR)。其核心的布局等价推理引擎是egglog(一个e-graph库)。虽然原型针对AWS Neuron SDK,但其算法是框架无关的,可以移植到其他基于IR的系统,如TensorFlow XLA或Megatron-LM。
6.2 局限性
- 范围限制:专注于计算图验证。运行时错误、调度竞争、编译前错误无法捕获。
- 可靠性而非完备性:Scalify是“可靠的”(sound),即它验证通过的图一定是正确的。但它不是“完备的”(complete),意味着可能存在一些正确的图,由于规则集未覆盖或算法限制,当前版本的Scalify无法验证。这需要不断扩展规则库。
- 模式支持:目前对Tensor Parallelism, Flash Decoding, Expert Parallelism支持良好,但对更复杂的流水线并行(Pipeline Parallelism)等涉及复杂跨设备通信和运行时语义的模式,需要更多工作来扩展支持。
- 根因分析:Scalify能精确定位到出现不一致的代码行,但有时无法直接揭示错误的根本原因(例如,是哪个工程师在什么背景下引入了这个错误逻辑),这仍需开发者人工分析。
6.3 未来与启示
Scalify为分布式机器学习系统的可靠性工程树立了一个新的标杆。它表明,通过巧妙的抽象(e-graph、关系传播、符号推理),可以对大规模、复杂的计算图进行高效的形式化验证。这对于未来越来越庞大、并行策略越来越复杂的模型开发至关重要。
对于开发者和团队而言,可以考虑将此类验证工具集成到CI/CD流程中,作为对分布式优化代码变更的强制性检查项,从而在代码合并前就拦截可能导致静默错误的修改。同时,工具揭示的“无法验证”区域,也可以指导我们编写更完备的测试用例。
从更广阔的视角看,Scalify的成功是编程语言、形式化方法与系统工程交叉融合的典范。它解决的不是一个纯理论问题,而是扎根于工业实践、具有明确度量标准和显著效用的实际问题。随着大模型训练和推理日益成为基础设施,这类确保底层计算正确性的工具,其价值只会与日俱增。
