当前位置: 首页 > news >正文

DiT 技术详解:把扩散模型的 U-Net 换成 Transformer,真正改变了什么

DiT 技术详解:把扩散模型的 U-Net 换成 Transformer,真正改变了什么

如果只用一句话解释 DiT:它把 latent diffusion 里的 U-Net 去噪网络换成了一个 ViT 风格的 Transformer。输入不再是卷积网络逐层处理的 feature map,而是 VAE latent 被切成的 patch token;扩散 timestep、类别标签等条件也不再靠 U-Net 里的 time embedding 到处注入,而是通过 adaLN-Zero 调制 Transformer block。

这句话听起来很像“架构替换”,但 DiT 真正有意思的地方不在“Transformer 也能生成图像”。更关键的是,它把图像扩散模型带到了一个更像 LLM/ViT 的缩放问题里:模型宽度、深度、token 数、forward Gflops 和 FID 之间出现了很清楚的关系。论文最硬的一句话是:DiT 的 Gflops 越高,FID 越低,而且这个趋势可以通过加深加宽 Transformer 或减小 patch size 来获得。

本文按工程视角拆 DiT。先把它放回 latent diffusion 的 pipeline,再看 patchify、条件注入、adaLN-Zero、缩放规律和代码实现。你不需要先背完整 DDPM 推导;只要知道扩散模型训练一个网络ϵθ(xt,t,c)\epsilon_\theta(x_t,t,c)ϵθ(xt,t,c)去预测噪声,DiT 改的是这个网络的 backbone。

DiT 在扩散 pipeline 里替换的是哪一块

Stable Diffusion 这类 latent diffusion model 通常有三块:VAE encoder 把图像压成 latent,去噪网络在 latent 空间里跑扩散反推,VAE decoder 再把 latent 解回像素图。DiT 保留了这条路线,只替换中间的去噪网络。

以 256×256 RGB 图像为例,论文使用 Stable Diffusion 的预训练 VAE,downsample factor 是 8。所以图像x∈R256×256×3x \in \mathbb{R}^{256 \times 256 \times 3}xR256×256×3会被编码成z∈R32×32×4z \in \mathbb{R}^{32 \times 32 \times 4}zR32×32×4。DiT 不是直接处理 256×256 像素,而是处理这个 32×32×4 的 latent grid。这样做很重要,因为如果直接在像素空间把图像切成 token,序列长度和算力会马上爆掉。

可以把 DiT 的位置画成这样:

image x │ ▼ VAE encoder E 冻结,不训练 │ ▼ latent z0: 32×32×4 │ add noise at timestep t ▼ noised latent zt │ ▼ DiT backbone: patchify → Transformer blocks → unpatchify │ ▼ predicted noise / covariance │ sampling loop ▼ latent z0_hat │ ▼ VAE decoder D 冻结,不训练 │ ▼ generated image

这里有一个容易被忽略的点:DiT 并没有提出新的扩散目标,也没有换掉 classifier-free guidance。它沿用 ADM/LDM 里很成熟的训练和采样设定,包括 learned covariance、250-step DDPM sampling、FID-50K 评估等。论文的实验设计其实很克制:尽量少动 diffusion recipe,把变量集中到 backbone 上。

这也是为什么 DiT 的结论比较干净。它是在问:如果把 U-Net 这个默认选择换成标准 Transformer,扩散模型还能不能按 compute scaling 的方式变好?答案是能,而且趋势相当稳定。

patchify:latent grid 怎么变成 token 序列

DiT 继承 ViT 的第一步:patchify。输入 latent 的形状是I×I×CI \times I \times CI×I×C,patch size 是p×pp \times pp×p,那么 token 数是:

T=(I/p)2 T = (I / p)^2T=(I/p)2

每个 patch 被线性投影到 hidden dimensionddd,再加上固定的二维 sine-cosine positional embedding。对于 256×256 图像,latent spatial size 是I=32I=32I=32。如果使用 DiT-XL/2,patch sizep=2p=2p=2,token 数就是16×16=25616 \times 16 = 25616×16=256。如果是 DiT-XL/4,token 数降到8×8=648 \times 8 = 648×8=64。如果是 DiT-XL/8,就只有4×4=164 \times 4 = 164×4=16个 token。

这给 DiT 带来一个很直接的旋钮:减小 patch size 会增加 token 数,也会显著增加 Transformer 的计算量。论文里说得很明确,patch size 减半会让 token 数变成四倍,因此 Transformer Gflops 至少变成四倍。更微妙的是,减小 patch size 几乎不增加参数量,因为参数主要在 Transformer block 的权重里,不在 token 数里。

这点和普通 CNN scaling 不太一样。你可以在参数量几乎不变的情况下,通过让模型处理更多 token 来提高 forward compute。DiT 的实验显示,这种 compute 增加确实能改善 FID。换句话说,DiT 的质量不只由参数量决定,也由“每次去噪到底看了多少 token、做了多少 attention/MLP 计算”决定。

官方 PyTorch 代码里对应的是这一行:

x=self.x_embedder(x)+self.pos_embed# (N, T, D)

x_embedder来自 timm 的PatchEmbed。后面所有 DiT block 都处理这个 token sequence。最后再把每个 token 解码成p×p×2Cp \times p \times 2Cp×p×2C,其中2C2C2C是因为模型同时预测噪声和 diagonal covariance。

DiT block:标准 Transformer,但条件注入不能随便做

把 latent patch 变成 token 以后,最自然的想法是直接套 ViT block:LayerNorm、self-attention、MLP、residual。问题在于,扩散模型不是普通图像分类。去噪网络每一步都需要知道 timestepttt,class-conditional ImageNet 还需要类别标签ccc。这些条件怎么进入 Transformer block,会明显影响效果。

DiT 论文比较了四种做法:

条件注入方式做法额外计算论文里的结论
In-context conditioning把 timestep 和 class embedding 当作额外 token 拼进序列很小简单,但效果较差
Cross-attention图像 token self-attention 后,再对条件 token 做 cross-attention最高,约 15% overhead计算更贵但不占优
adaLN用条件向量生成 LayerNorm 的 scale/shift很小比前两者更高效
adaLN-Zero在 adaLN 上加 residual gate,并把 gate 初始化为 0很小最好,后续实验默认使用

这个结果有点反直觉。很多 text-to-image 模型里 cross-attention 是核心部件,所以容易下意识觉得 cross-attention 更强。但在 DiT 的 ImageNet class-conditional 设置里,条件只是 timestep 和 class label,信息量很小。为这种短条件专门加 cross-attention,不一定划算。

adaLN 的思路更像 FiLM:先把 timestep embedding 和 label embedding 相加,得到条件向量ccc,再由一个 MLP 生成每个 block 的调制参数。官方实现里,每个 DiTBlock 的调制层输出 6 组向量:

shift_msa,scale_msa,gate_msa,shift_mlp,scale_mlp,gate_mlp=\ self.adaLN_modulation(c).chunk(6,dim=1)

然后 attention branch 和 MLP branch 分别这样走:

x=x+gate_msa.unsqueeze(1)*self.attn(modulate(self.norm1(x),shift_msa,scale_msa))x=x+gate_mlp.unsqueeze(1)*self.mlp(modulate(self.norm2(x),shift_mlp,scale_mlp))

modulate很简单:

defmodulate(x,shift,scale):returnx*(1+scale.unsqueeze(1))+shift.unsqueeze(1)

也就是说,条件向量不作为 token 参加 attention,而是改变每个 block 里 normalization 后的表示,并通过 gate 控制 residual branch 的强度。

adaLN-Zero 为什么是 DiT 的关键小改动

adaLN-Zero 的“Zero”不是名字装饰。它把每个 DiT block 初始成接近 identity function。具体做法是让adaLN_modulation最后一层线性层的权重和 bias 初始化为 0,于是初始时 shift、scale、gate 都是 0。残差分支一开始被 gate 关掉,整个 block 更像恒等映射。

官方代码里可以直接看到:

forblockinself.blocks:nn.init.constant_(block.adaLN_modulation[-1].weight,0)nn.init.constant_(block.adaLN_modulation[-1].bias,0)nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight,0)nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias,0)nn.init.constant_(self.final_layer.linear.weight,0)nn.init.constant_(self.final_layer.linear.bias,0)

这和一些 ResNet/扩散 U-Net 的初始化习惯一致:残差块刚开始不要太激进,先让网络从稳定的近似恒等映射开始学。扩散模型的训练本来就要处理不同 noise level 下的输入,如果一开始每个 block 都强烈扰动 token,训练会更难。

论文的消融很有说服力。四种 block 设计都在 DiT-XL/2 上比较,Gflops 大致接近:in-context 119.4,cross-attention 137.6,adaLN 118.6,adaLN-Zero 118.6。结果是 adaLN-Zero 在训练过程中 FID 最低,400K steps 时几乎把 in-context 的 FID 降到一半。这个结论说明 DiT 不是“把 ViT 塞进 diffusion 就行”。条件注入和初始化是成败点。

从工程实现看,adaLN-Zero 还有一个优点:它不像 cross-attention 那样引入额外条件序列,也不强依赖复杂的注意力 mask。对于 class-conditioned 或 timestep-conditioned 模型,这种调制式条件注入非常干净。

DiT-S/B/L/XL 和 /2 /4 /8 到底怎么命名

DiT 的模型名有两部分:前面的 S/B/L/XL 表示 Transformer 主体大小,后面的/2/4/8表示 latent patch size。

论文的主配置如下:

模型层数 Nhidden size dheads在 I=32, p=4 时的 Gflops
DiT-S1238461.4
DiT-B12768125.6
DiT-L2410241619.7
DiT-XL2811521629.1

DiT-XL/2就是 28 层、hidden size 1152、16 heads、patch size 2。对于 256×256 图像,它处理 32×32×4 latent,patch size 2 产生 256 个 token。对于 512×512 图像,latent 是 64×64×4,patch size 2 产生 1024 个 token,所以 512 分辨率下的 DiT-XL/2 Gflops 会明显更高。

论文里最值得记的不是某个具体配置,而是两条 scaling 路线:

  1. 固定 patch size,加深加宽 Transformer:S → B → L → XL。
  2. 固定模型大小,减小 patch size:/8 → /4 → /2。

两条路都会增加 Gflops,也都会改善 FID。更关键的是,参数量不是唯一解释变量。比如固定 DiT-XL,只把 patch size 从 4 改成 2,参数量几乎不变,但 token 数和 Gflops 大幅增加,FID 仍然显著变好。

这就是 DiT 对后续生成模型架构的影响:图像生成不必永远围绕 U-Net 设计。只要把输入组织成 token,并找到合适的条件注入方式,扩散模型也能进入 Transformer 的 scaling 逻辑。

训练设置:DiT 的 recipe 其实很保守

DiT 的实验是在 ImageNet class-conditional generation 上做的,分辨率包括 256×256 和 512×512。训练时使用 AdamW,batch size 256,学习率1×10−41 \times 10^{-4}1×104,没有 weight decay,只用 horizontal flip 作为数据增强。论文还提到,他们没有发现 ViT 训练里常见的 warmup 或正则化是必需的,训练过程稳定,没有观察到 Transformer 训练中常见的 loss spike。

扩散部分基本沿用 ADM:1000-step linear variance schedule,预测噪声和 learned covariance,采样评估用 250 DDPM steps。评估用 FID-50K,并用 ADM 的 TensorFlow evaluation suite 来保证和 prior work 可比。

最大模型的训练成本不低。论文报告 DiT-XL/2 在 TPU v3-256 pod 上训练速度约 5.7 iterations/second,global batch size 256。官方 PyTorch repo 后来提供了 DDP 训练脚本,也说明用 8×A100 训练 DiT-XL/2、用 4×A100 训练 DiT-B/4,可以在数十万步范围内复现 JAX 结果到合理随机波动内。

如果只是跑预训练模型,官方 repo 给出的最小命令很简单:

gitclone https://github.com/facebookresearch/DiT.gitcdDiT condaenvcreate-fenvironment.yml conda activate DiT python sample.py --image-size512--seed1

训练自己的 class-conditional DiT:

torchrun--nnodes=1--nproc_per_node=N train.py\--modelDiT-XL/2\--data-path /path/to/imagenet/train

如果要严肃复现实验,最容易踩坑的是 FID 评估而不是模型 forward。FID 对 resize、采样数量、VAE decoder、guidance scale 都敏感。官方 README 里强调,PyTorch 训练结果表里的 FID 是 250 DDPM sampling steps、mseVAE decoder、无 guidance(cfg-scale=1)条件下算的。

实验结果该怎么看:不是“Transformer 赢了”,而是“compute scaling 很干净”

DiT-XL/2 在 256×256 ImageNet 上,使用 classifier-free guidance scale 1.50 时 FID-50K 达到 2.27,超过论文比较中的 prior diffusion models。512×512 上,DiT-XL/2-G 达到 3.04 FID,也优于当时对比的 ADM、ADM-U、ADM-G 等结果。

但我觉得这篇论文真正有价值的结果不是 SOTA 表格,而是 Figure 6、Figure 8、Figure 9 那组 scaling 分析。它们回答了三个更底层的问题:

第一,增加 Transformer depth/width 有用吗?有。固定 patch size 时,从 S/B/L 到 XL,FID 在训练各阶段都改善。

第二,增加 token 数有用吗?也有。固定模型大小时,把 patch size 从 8 降到 4、再降到 2,FID 同样持续改善。

第三,小模型多采样几步能不能补回来?很难。论文比较了不同 sampling steps 下的 FID,发现增加 sampling compute 不能弥补 backbone compute 不足。比如 DiT-L/2 用 1000 sampling steps 时,每张图采样计算量约 80.7 Tflops;DiT-XL/2 用 128 steps 只用约 15.2 Tflops,但 FID-10K 仍然更好(23.7 vs 25.9)。

这对实际训练很有启发。扩散模型的质量不是只靠“采样时多跑几步”堆出来的。backbone 本身的容量和每步 forward compute 仍然很关键。对于 DiT,训练一个足够大的模型,可能比在小模型上用更重的 sampler 更划算。

从代码看一次 forward

把官方models.py抽象一下,DiT forward 主要是五步:

defforward(self,x,t,y):# 1. latent patches -> token sequencex=self.x_embedder(x)+self.pos_embed# 2. timestep / label embeddingst=self.t_embedder(t)y=self.y_embedder(y,self.training)c=t+y# 3. Transformer blocks with adaLN-Zeroforblockinself.blocks:x=block(x,c)# 4. decode each token to patch predictionx=self.final_layer(x,c)# 5. token sequence -> latent gridx=self.unpatchify(x)returnx

这里的x不是像素图,而是 noisy latent。t是 diffusion timestep。y是类别标签,训练时会按class_dropout_prob随机 drop 成 null label,用来支持 classifier-free guidance。

官方forward_with_cfg还有一个实现细节:为了可复现,默认只对前三个 channel 应用 classifier-free guidance,而不是对所有 output channel。代码注释说标准做法可以改成对所有 channels 做 CFG。这种细节如果不注意,复现出来的采样结果可能和 README 或论文不一致。

另一个实现细节是位置编码。DiT 使用固定二维 sin-cos positional embedding,requires_grad=False。这和 ViT/MAE 的习惯一致,也让模型结构更简单。DiT 本质上把 latent grid 当成一张“低分辨率图像”,所以二维位置编码比一维 learnable embedding 更自然。

DiT 和 U-Net 的差别,不只是有没有卷积

U-Net 的强项是局部归纳偏置和多尺度结构。高分辨率图像生成里,U-Net 通过 downsample/upsample 路径在不同空间尺度处理特征,skip connection 又保留细节。这个设计很适合图像。

DiT 的强项是统一的 token 表示和清晰的 scaling 规则。它没有显式金字塔,也没有 U-Net 的多尺度 skip。所有 token 在同一 hidden dimension 里反复经过 attention 和 MLP。局部性不是结构硬编码出来的,而更多来自 latent patch、位置编码和训练数据。

这不是说 DiT 在所有场景都天然优于 U-Net。原始 DiT 的 ImageNet class-conditional 设置比较干净,条件也很短。如果换成 text-to-image,条件变成长文本,cross-attention 或更复杂的 multimodal attention 又会回来。后来的 MMDiT、PixArt、Stable Diffusion 3 等路线,本质上都是在 DiT/Transformer backbone 上重新设计文本条件、训练效率和高分辨率生成。

所以更准确的判断是:DiT 证明了扩散模型不需要永远依赖 U-Net inductive bias。Transformer backbone 可以在 latent diffusion 里工作,而且可以按 compute 规律稳定变好。但具体到 text-to-image、video、3D 或 controllable generation,条件组织和训练 recipe 仍然决定上限。

实践里什么时候该考虑 DiT

如果你在做 image/video generation 或多模态生成模型,DiT 值得考虑的场景通常有几类。

第一,模型规模会继续变大。Transformer 的工程生态更成熟:FlashAttention、sequence parallel、tensor parallel、checkpointing、fused MLP、KV/attention 优化,这些都更容易迁移到 DiT 类架构上。U-Net 当然也能优化,但 Transformer scaling 的工具链更完整。

第二,你的输入和条件天然是 token。比如文本、动作、语音、相机轨迹、agent state、layout token、视频 patch。如果所有东西都能变成 token,那么 Transformer backbone 的统一接口会很舒服。相反,如果任务强依赖局部纹理和多尺度 skip,U-Net 仍然可能更省算力。

第三,你关心 scaling law 式的实验设计。DiT 给了一个很清楚的坐标系:模型大小、patch size、token 数、Gflops、训练步数、FID。你可以系统扫配置,而不是只在 U-Net channel multiplier、attention resolution、resblock 数量里调参。

实际落地时,我会先看三个约束:

约束更偏 DiT更偏 U-Net
数据/算力有足够训练 compute,计划 scale算力有限,需要强 inductive bias
条件形式多模态 token、长上下文、需要统一建模条件简单,局部控制为主
工程目标想复用 Transformer 优化栈想用成熟 diffusion U-Net 生态

DiT 不是免费午餐。attention 对 token 数敏感,patch size 一小,计算量马上上去。512×512 下 DiT-XL/2 处理 1024 个 latent tokens,Gflops 达到 524.6。更高分辨率或视频任务如果直接照搬,会遇到序列长度问题。因此后续工作经常会引入更高效的 attention、factorized attention、latent compression 或分层结构。

一个简化版 DiT mental model

如果你想快速在脑子里跑一遍 DiT,可以用下面这个 mental model:

1. VAE 把图像压成小 latent map。 2. diffusion 给 latent 加噪声,得到 zt。 3. DiT 把 zt 切成 patch tokens。 4. timestep + label 变成一个条件向量 c。 5. 每个 Transformer block 用 c 生成 adaLN 的 shift/scale/gate。 6. token 经过 self-attention 和 MLP,预测噪声与方差。 7. 采样循环反复调用 DiT,把纯噪声 latent 还原成可解码 latent。 8. VAE decoder 把 latent 变回图像。

如果再压缩一点:DiT = LDM 的 latent 空间 + ViT patch tokens + adaLN-Zero 条件注入 + compute scaling。

这个公式比“U-Net 换 Transformer”更有信息量。因为只换 Transformer 不够,必须同时解释 latent patch 怎么组织、条件怎么进 block、为什么 zero init 稳定、以及质量为什么跟 Gflops 强相关。

局限和阅读建议

DiT 原论文的实验场景是 ImageNet class-conditional generation,不是今天更常见的开放词表 text-to-image。它证明了 Transformer backbone 在扩散图像生成里可行,也给了清晰的缩放证据,但没有解决文本对齐、复杂 prompt following、超高分辨率生成或视频长程一致性。

读这篇论文时,建议不要只盯着最终 FID。更值得反复看的有三处:Figure 3 的 block design,Figure 6/8 的 scaling 曲线,以及官方models.py里的 adaLN-Zero 实现。看完这三处,DiT 的核心基本就通了。

后续如果继续读,可以沿两条线走。一条是架构线:PixArt、MMDiT、Stable Diffusion 3 这类模型如何把文本条件和 DiT backbone 结合起来。另一条是效率线:FlashAttention、sequence length reduction、latent tokenization、video DiT 如何处理更长序列。DiT 本身是起点,不是终点。

参考资料

  • William Peebles, Saining Xie, “Scalable Diffusion Models with Transformers”, arXiv:2212.09748 / ICCV 2023, retrieved 2026-06-30, https://arxiv.org/abs/2212.09748
  • ICCV 2023 open access paper PDF, retrieved 2026-06-30, https://openaccess.thecvf.com/content/ICCV2023/papers/Peebles_Scalable_Diffusion_Models_with_Transformers_ICCV_2023_paper.pdf
  • DiT official project page, retrieved 2026-06-30, https://www.wpeebles.com/DiT
  • facebookresearch/DiT official PyTorch implementation, retrieved 2026-06-30, https://github.com/facebookresearch/DiT
  • facebookresearch/DiTmodels.py, retrieved 2026-06-30, https://github.com/facebookresearch/DiT/blob/main/models.py
http://www.jsqmd.com/news/1101970/

相关文章:

  • Anthropic模型能力演进与访问控制机制解析
  • 曲直天涯路
  • 从波形到中断:一篇看懂 I2C 通信原理、地址、ACK 与调试方法
  • 汽车级MCU评估板硬件设计解析:电源、时钟与调试接口实战
  • Bombesin (8-14) ;WAVGHLM-NH₂
  • iOS激活锁免费绕过教程:5步解锁iPhone 6s-X设备
  • ASD433A评估板硬件设计解析与PowerPC MCU开发实战指南
  • 2026申博机构交付颗粒度测评|从落地精细度甄别正规辅导平台
  • MuleSoft+LangChain企业级AI编排实战:打通LLM与CRM/ERP
  • 嵌入式定位导航:PIC18F86J15与13DOF传感器融合方案
  • 基于WSEN-ISDS和MKV44F128的6DOF运动追踪系统实现
  • 方向科技 GEO 系统与市面 AI 搜索优化软件深度横评
  • XSS漏洞实战指南:从原理到防御的Web安全必修课
  • Three.js 官方选择辉光简化版教程
  • 国产大模型会回答之后,怎样用魔珐星云补齐具象交互?
  • 【小白也能轻松玩转龙虾】虾壳云一键部署轻量化 AI,低配设备流畅运行 OpenClaw v2.7.9(附最新安装包)
  • PowerPC评估板ASD433A硬件设计解析与调试实战
  • 3分钟实现Windows桌面分区革命:NoFences开源桌面管理终极方案
  • Visual C++运行库终极指南:一键解决Windows软件依赖问题
  • Codex 实战:从基础调用到稳定运行
  • 权限状态机与渐进式授权:从用户体验到子 Agent 代理
  • 云服务器SSRF漏洞利用IMDS窃取IAM凭证的攻防实战
  • UniExtract2:终极文件解压工具,一键提取500+种格式的完整指南
  • 花箱花坛花槽花钵哪家好?优质靠谱供应商挑选实用指南
  • 【仅限前500名开发者】OpenAI发布会技术密钥包:含Model Context Protocol v2规范、Rate Limiting 3.0策略表、Error Code映射速查表
  • 终极CSV查看指南:用csview快速美化你的数据表格
  • 测试内容测试内容测试内容
  • 微信网页版解锁插件:5分钟解决Chrome/Firefox/Edge无法登录问题
  • Sora已上线全球公测,可灵AI却悄然升级V2.3——两大平台训练成本、推理延迟、版权合规性全对比,现在不看就晚了!
  • HTML 早已不是标签了,它现在是系统级接口:这 9 个 API 直接干翻常用 JS 库 _