DiT(Diffusion Transformer)形象讲解(建议先看懂前几篇文章)
文章目录
- 一、为什么图像生成不再是“写作文”,而是“洗照片”?
- 1. 核心比喻:DiT 像一群修复师在擦一幅被雪花盖住的海报
- 二、先建立全局观:DiT 到底由哪几部分组成?
- 1. VAE:先把高清大图压成“浓缩草稿”
- 2. Forward Diffusion:故意往草稿上泼噪声
- 3. Patchify:把 latent 切成一个个 patch token
- 4. Position Embedding:告诉模型“每块在画面的哪里”
- 5. Timestep / Condition Embedding:告诉模型“现在是第几轮去噪”
- 6. DiT Blocks:让所有 patch 进行全局群聊
- 三、DiT 最核心的比喻:不是“逐字续写”,而是“整图会诊”
- 1. 图像里的 Q、K、V 到底在表达什么?
- 四、数学上它到底在预测什么?不是类别,而是“噪声”
- 五、为什么 DiT 不用 causal mask?这点和 GPT 完全不同
- 1. GPT 为什么必须 mask?
- 2. DiT 为什么不需要?
- 六、DiT 最关键的条件注入:adaLN 和 adaLN-Zero 到底是什么?
- 1. 先回忆普通 LayerNorm 是什么
- 2. adaLN:把“老师指令”变成每层的调音旋钮
- 3. adaLN-Zero:为什么最后还要加个 Zero?
- 4. 直观比喻:新员工先旁听,再逐步接手工作
- 七、DiT Block 长什么样?和标准 Transformer Block 有什么不同?
- 八、DiT 和 U-Net 到底有什么结构差异?
- 1. U-Net:像金字塔式的卷积施工队
- 2. DiT:像一群平级专家在全图会议室里反复会诊
- 3. 粗暴对比
- 九、训练和采样流程,一口气串起来看
- 1. 训练时:故意弄脏,再逼模型学会清洗
- 2. 采样时:从纯噪声开始,反复显影
- 十、极简 PyTorch:手搓一个 Mini DiT 骨架
- 观察这段代码时,你重点盯 4 件事
- 十三、下一步该学什么?
本文通过直观的比喻和极简的 PyTorch 代码,彻底拆解 DiT(Diffusion Transformer)的核心——它到底是如何把“满屏雪花噪声”,一步步变成清晰图像的。
一、为什么图像生成不再是“写作文”,而是“洗照片”?
如果你已经理解了 Transformer 做文本生成,那么你脑子里很容易先冒出一个想法:
既然文本可以一个 token 一个 token 地往后写,那图像是不是也能一个像素一个像素地往后写?
理论上可以,但图像和文本有一个巨大的不同:
- 文本天然是一条序列,词有明显的前后顺序。
- 图像天然是一个二维平面,左上角、右下角、远处背景和近处主体是同时存在、彼此耦合的。
如果你硬要像写作文一样,一个像素一个像素地生成图片,就会非常别扭:
- 生成左眼时,还没看到右眼,脸容易歪。
- 生成前景人物时,还没全局理解背景,光影容易乱。
- 图像的“整体一致性”很难保证。
所以扩散模型(Diffusion Model)走了另一条完全不同的路线:
它不从左到右“写”图,而是从一张全是噪声的脏图开始,一轮一轮把噪声擦掉,让图像逐渐“显影”出来。
1. 核心比喻:DiT 像一群修复师在擦一幅被雪花盖住的海报
想象你面前有一张海报,但海报表面被泼满了白色雪花噪点,几乎看不清内容。
这时来了一群修复师,他们不是每次只修一个像素,而是:
- 先把海报切成很多小方块
- 每个修复师负责看一块
- 所有修复师先开会,互相交流全局情况
- 然后每人决定:“我这一块,这一轮该擦掉多少噪声?”
- 擦掉一点后,再开下一轮会
一轮轮下来,原本满是雪花的海报就逐渐显出轮廓、颜色、细节。
这就是 DiT 的本质:
DiT 不是直接“画图”,而是在“反复开会,协同去噪”。
二、先建立全局观:DiT 到底由哪几部分组成?
如果粗暴地用一句话概括 DiT:
DiT = 在 latent 空间里,把图像切成 patch token,再用 Transformer 反复预测噪声。
它大致可以拆成 6 个零件:
1. VAE:先把高清大图压成“浓缩草稿”
真实图片太大了。
比如一张256 x 256 x 3的 RGB 图,直接让 Transformer 对它做全局注意力,成本很高。
所以 DiT 通常不会直接处理原始像素,而是先用一个VAE Encoder把图片压缩到 latent 空间:
- 原图:
256 x 256 x 3 - latent:
32 x 32 x 4
这就像先把高清海报压成一张“浓缩草稿纸”:
- 分辨率小了很多
- 但主要信息还保留着
- 后面只需要在这张草稿上做去噪
2. Forward Diffusion:故意往草稿上泼噪声
训练时,模型不是直接看干净图,而是故意把 latent 草稿弄脏。
设干净 latent 是z zz,随机采样一个时间步t tt,往里面加上高斯噪声,得到z t z_tzt。
你可以把这一步理解成:
老师先把一张草稿图泼脏,再拿给学生看:“你来猜猜我刚才泼了多少脏水。”
3. Patchify:把 latent 切成一个个 patch token
到了 Transformer 这里,不能直接吃二维特征图,它更擅长处理 token 序列。
所以要把 latent 切成小块 patch。
例如:
- latent 大小:
32 x 32 x 4 - patch size:
2 x 2
那就会得到:
- patch 数量:
(32/2) * (32/2) = 256 - 每个 patch 的原始维度:
2 * 2 * 4 = 16
然后再通过一个线性层,把每个 patch 投影成统一维度的 token embedding。
于是:
一张图,就变成了一串图像 token。
这一步非常像 NLP:
- 一句话 -> 一串词 token
- 一张 latent 图 -> 一串 patch token
4. Position Embedding:告诉模型“每块在画面的哪里”
如果没有位置编码,Transformer 只知道自己拿到了 256 个 patch,但不知道:
- 谁在左上角
- 谁在中心
- 谁在右下角
这就像你把拼图块全倒在桌子上,但没告诉模型哪块原来属于天空、哪块属于脸。
所以 DiT 也必须给 patch token 加上位置编码。
区别只是:
- 文本位置编码强调一维顺序
- 图像位置编码强调二维空间位置
5. Timestep / Condition Embedding:告诉模型“现在是第几轮去噪”
这一点是 DiT 和普通 ViT 最大的不同之一。
DiT 每次输入的不是固定图像,而是某个噪声阶段的图z t z_tzt。
所以模型必须知道:
- 现在是早期,噪声很重?
- 还是后期,图已经很清晰?
- 这张图要求生成“猫”还是“宇航员”?
因此 DiT 会把这些条件编码成向量:
- 时间步 embedding:告诉模型现在在第几轮
- 类别/文本 embedding:告诉模型想生成什么
你可以把它理解成每轮开会前,老师先发指令:
- “现在噪声很大,先抓轮廓。”
- “现在接近收尾,重点修眼睛、纹理和光影。”
- “这张图的目标是金毛犬,不是猫。”
6. DiT Blocks:让所有 patch 进行全局群聊
进入主体后,就是一层层 Transformer Block。
每个 patch 都会生成自己的:
- Q(我需要什么信息)
- K(我能提供什么线索)
- V(我真正携带的内容)
然后所有 patch 一起做 Self-Attention,全局互相交流。
这意味着:
- 左上角的一块天空,可以直接参考右边的太阳
- 脸部 patch 可以直接和肩膀、头发、背景光线交流
- 前景和背景可以在同一层里完成全局协调
这就是 Transformer 放到图像生成里的最大魅力:
全图所有区域都能直接“开群聊”,不是只能看局部邻居。
三、DiT 最核心的比喻:不是“逐字续写”,而是“整图会诊”
文本 Transformer 更像:
一群词在讨论,下一句该写哪个词。
DiT 更像:
一群 patch 在讨论,这一轮每块应该擦掉多少噪声。
1. 图像里的 Q、K、V 到底在表达什么?
假设画面里有一只猫,背景是沙发。
某个 patch 位于猫眼附近,它的 Q 可能在“想”:
- “我现在需要确认自己到底属于眼睛、脸毛,还是背景阴影?”
沙发区域的 patch 的 K 可能在“展示”:
- “我是大块的平坦纹理,像背景,不像五官。”
猫耳朵附近的 patch 的 K 可能在“展示”:
- “我属于猫头的边缘结构,和眼睛区域关系密切。”
于是猫眼 patch 在做注意力时,会重点关注:
- 邻近的脸部 patch
- 耳朵 patch
- 光照相关的背景 patch
而不会把大量注意力浪费在无关区域。
最终这个 patch 就能更清楚地判断:
“我这里应该保留猫眼结构,不该继续像噪声一样模糊。”
所以 Self-Attention 在 DiT 里干的事情,不是语言里的“上下文消歧”,而是:
在全局画面中,确定每个局部到底应该长成什么样。
四、数学上它到底在预测什么?不是类别,而是“噪声”
这是初学 DiT 最容易混淆的地方。
分类网络最后预测的是:
- 这是一只猫
- 这是一辆车
GPT 最后预测的是:
- 下一个 token 是谁
但 DiT 最后预测的通常是:
- 当前图里加进去的噪声ϵ \epsilonϵ
- 或者等价形式v vv
它不是直接输出“最终图像”,而是输出:
“这张带噪图里,哪些部分像噪声,应该减掉多少。”
于是训练目标通常是一个简单的均方误差:
L = ∥ ϵ ^ − ϵ ∥ 2 \mathcal{L} = \| \hat{\epsilon} - \epsilon \|^2L=∥ϵ^−ϵ∥2
其中:
- ϵ \epsilonϵ:真实加进去的噪声
- ϵ ^ \hat{\epsilon}ϵ^:DiT 预测出来的噪声
这就像老师真的往图上泼了一桶脏水,而模型要尽量精确地回答:
“你刚才泼的是这一桶脏水,对吧?”
如果它猜得越来越准,反过来就越来越会“洗照片”。
五、为什么 DiT 不用 causal mask?这点和 GPT 完全不同
你学过文本 Transformer 后,脑子里很容易有个惯性:
Transformer 不是都要 mask 吗?
答案是:DiT 不需要 GPT 那种 causal mask。
1. GPT 为什么必须 mask?
因为 GPT 是做自回归生成:
- 预测第 5 个词时,不能偷看第 6 个词
- 否则训练时等于作弊
所以它必须使用下三角 mask,把未来位置遮住。
2. DiT 为什么不需要?
因为 DiT 的任务不是“预测未来 token”,而是:
给定一整张带噪图,同时判断所有位置的噪声。
也就是说,在 DiT 的每一步里:
- 所有 patch 都已经同时存在
- 大家可以彼此看见
- 没有“未来 token 不能看”的约束
这就像全体修复师围着同一张海报开会:
- 左上角的修复师当然可以看右下角
- 眼睛区域当然可以参考嘴巴和耳朵
- 背景当然可以参考主体的轮廓
所以 DiT 使用的是普通双向自注意力,不是自回归的单向 mask 注意力。
一句话记忆:
GPT 的时间轴在“词序”上,所以怕偷看未来。
DiT 的时间轴在“去噪轮数”上,所以同一轮里所有 patch 都能全局互看。
六、DiT 最关键的条件注入:adaLN 和 adaLN-Zero 到底是什么?
这部分是 DiT 真正有“灵魂细节”的地方。
因为仅仅把时间步 embedding 加到输入里,还不够强。
模型更希望在每一层都知道:
- 当前噪声阶段是多少
- 当前条件是什么
于是 DiT 引入了Adaptive LayerNorm(adaLN),以及更常见的adaLN-Zero。
1. 先回忆普通 LayerNorm 是什么
普通 LayerNorm 会对一个 token 的特征做标准化:
LN ( x ) = γ ⋅ x − μ σ + β \text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sigma} + \betaLN(x)=γ⋅σx−μ+β
其中:
- μ , σ \mu, \sigmaμ,σ:由当前 token 自己算出来
- γ , β \gamma, \betaγ,β:是固定可学习参数
作用就是把数值拉稳。
2. adaLN:把“老师指令”变成每层的调音旋钮
在 adaLN 中,γ \gammaγ和β \betaβ不再是固定参数,而是由条件向量动态生成的。
条件向量可以来自:
- timestep embedding
- class embedding
- text embedding
也就是说,不同时间步、不同生成目标,会让 LayerNorm 的缩放和平移方式都不同。
这就像每层前面多了一个“总指挥调音台”:
- 现在是早期噪声阶段 -> 把模型调到更关注大轮廓
- 现在是后期精修阶段 -> 把模型调到更关注局部细节
- 现在目标是狗 -> 强化和狗相关的结构模式
3. adaLN-Zero:为什么最后还要加个 Zero?
DiT 论文里常用的是adaLN-Zero。
它的关键思想是:
让每个块一开始先“几乎什么都不做”,训练时再慢慢学会如何介入。
具体做法可以粗略理解成:
- 条件向量不只生成 shift / scale
- 还生成一个 gate(门控系数)
- 这个 gate 初始接近 0
于是某个 block 的残差分支一开始近似是关闭的:
x ← x + gate ⋅ Sublayer ( adaLN ( x , c ) ) x \leftarrow x + \text{gate} \cdot \text{Sublayer}(\text{adaLN}(x, c))x←x+gate⋅Sublayer(adaLN(x,c))
当gate ≈ 0时,这层一开始对主干影响很小,训练更稳定。
4. 直观比喻:新员工先旁听,再逐步接手工作
普通残差块像一个新员工第一天上班,就直接大幅改流程,容易把系统搞乱。
adaLN-Zero 像是:
- 新员工第一天先只旁听,不随便动系统
- 随着训练进行,模型逐渐学会什么时候该介入、介入多少
这能显著提升大模型训练的稳定性。
七、DiT Block 长什么样?和标准 Transformer Block 有什么不同?
如果用最简化视角看,DiT Block 还是那两个老朋友:
- Self-Attention
- MLP
但每个子层前面会加上条件调制(adaLN),并且通常带门控。
一个非常典型的简化版流程是:
- 对输入 token 做 adaLN
- 送入 Self-Attention
- 用 gate 控制这条残差分支的强度
- 再对主干做一次 adaLN
- 送入 MLP
- 再用 gate 控制 MLP 分支的强度
你可以把它理解成:
标准 Transformer Block 是“大家开会 + 各自思考”。
DiT Block 则是“老师先发这轮指令,再开会,再思考”。
八、DiT 和 U-Net 到底有什么结构差异?
在 DiT 爆火之前,扩散模型最经典的 backbone 是 U-Net。
两者都能做扩散去噪,但世界观不太一样。
1. U-Net:像金字塔式的卷积施工队
U-Net 的典型特点是:
- 下采样:不断压缩分辨率,提取高层语义
- 上采样:逐步恢复分辨率
- 跳跃连接:把浅层细节直接传给深层
它像一支分层施工队:
- 小工先处理局部纹理
- 中层处理结构
- 高层处理全局语义
- 最后再逐级还原回来
2. DiT:像一群平级专家在全图会议室里反复会诊
DiT 的核心则是:
- 把图切成 patch
- 直接作为 token 序列处理
- 每层都用全局 self-attention
它不像 U-Net 那样强依赖多尺度卷积金字塔,而是更像:
所有 patch 从一开始就处在同一个大会场里,每一层都能做全局交流。
3. 粗暴对比
- U-Net 强在卷积归纳偏置强,天生适合图像局部结构
- DiT 强在 Transformer 扩展性好,模型做大后全局建模能力强
- U-Net 像“图像工程老将”
- DiT 像“Transformer 化的新一代主力”
所以很多人会把它们类比成:
- U-Net:扩散时代的卷积王者
- DiT:扩散时代的 Transformer 主力
九、训练和采样流程,一口气串起来看
1. 训练时:故意弄脏,再逼模型学会清洗
训练流程可以概括为:
- 输入真实图片x xx
- 用 VAE 编码成 latentz zz
- 随机采样时间步t tt
- 加噪得到z t z_tzt
- patchify + 位置编码
- 加入时间步和类别/文本条件
- 送入 DiT,预测噪声ϵ ^ \hat{\epsilon}ϵ^
- 和真实噪声ϵ \epsilonϵ做 MSE loss
2. 采样时:从纯噪声开始,反复显影
真正生成图像时,没有真实图片,只有随机噪声:
- 从纯高斯噪声 latent 开始
- 把当前z t z_tzt送入 DiT
- DiT 预测这一轮的噪声
- 调度器(scheduler)据此把图变干净一点
- 重复很多轮
- 最后得到干净 latent
- 用 VAE Decoder 解码成最终图片
一句话总结:
训练是“老师泼脏水,学生学会认脏水”。
采样是“学生独立洗照片”。
十、极简 PyTorch:手搓一个 Mini DiT 骨架
下面这段代码不是工业级实现,但足够帮助你把 DiT 的核心骨架真正串起来。
importtorchimporttorch.nnasnnclassPatchEmbed(nn.Module):def__init__(self,in_channels=4,patch_size=2,hidden_size=256):super().__init__()self.patch_size=patch_size# 用卷积做 patchify,等价于切块后线性投影self.proj=nn.Conv2d(in_channels,hidden_size,kernel_size=patch_size,stride=patch_size)defforward(self,x):# x: (B, C, H, W)x=self.proj(x)# (B, D, H/P, W/P)x=x.flatten(2).transpose(1,2)# (B, N, D)returnxclassTimestepEmbedder(nn.Module):def__init__(self,hidden_size):super().__init__()self.mlp=nn.Sequential(nn.Linear(hidden_size,hidden_size),nn.SiLU(),nn.Linear(hidden_size,hidden_size),)defforward(self,t_embed):# 这里假设外部已经把 t 变成 hidden_size 维向量returnself.mlp(t_embed)classAdaLNModulation(nn.Module):def__init__(self,hidden_size):super().__init__()# 这里一次性生成 6 组参数:# attn 的 shift/scale/gate + mlp 的 shift/scale/gateself.net=nn.Sequential(nn.SiLU(),nn.Linear(hidden_size,6*hidden_size))defforward(self,cond):returnself.net(cond).chunk(6,dim=-1)classDiTBlock(nn.Module):def__init__(self,hidden_size=256,num_heads=8,mlp_ratio=4.0):super().__init__()self.norm1=nn.LayerNorm(hidden_size,elementwise_affine=False)self.norm2=nn.LayerNorm(hidden_size,elementwise_affine=False)self.attn=nn.MultiheadAttention(hidden_size,num_heads,batch_first=True)mlp_hidden=int(hidden_size*mlp_ratio)self.mlp=nn.Sequential(nn.Linear(hidden_size,mlp_hidden),nn.GELU(),nn.Linear(mlp_hidden,hidden_size))self.adaLN=AdaLNModulation(hidden_size)defmodulate(self,x,shift,scale):returnx*(1+scale.unsqueeze(1))+shift.unsqueeze(1)defforward(self,x,cond):shift_msa,scale_msa,gate_msa,shift_mlp,scale_mlp,gate_mlp=self.adaLN(cond)# Attention 子层x_norm=self.modulate(self.norm1(x),shift_msa,scale_msa)attn_out,_=self.attn(x_norm,x_norm,x_norm)x=x+gate_msa.unsqueeze(1)*attn_out# MLP 子层x_norm=self.modulate(self.norm2(x),shift_mlp,scale_mlp)mlp_out=self.mlp(x_norm)x=x+gate_mlp.unsqueeze(1)*mlp_outreturnxclassFinalLayer(nn.Module):def__init__(self,hidden_size=256,patch_size=2,out_channels=4):super().__init__()self.norm=nn.LayerNorm(hidden_size,elementwise_affine=False)self.linear=nn.Linear(hidden_size,patch_size*patch_size*out_channels)self.adaLN=nn.Sequential(nn.SiLU(),nn.Linear(hidden_size,2*hidden_size))defforward(self,x,cond):shift,scale=self.adaLN(cond).chunk(2,dim=-1)x=self.norm(x)x=x*(1+scale.unsqueeze(1))+shift.unsqueeze(1)x=self.linear(x)returnxclassMiniDiT(nn.Module):def__init__(self,input_size=32,in_channels=4,patch_size=2,hidden_size=256,depth=6,num_heads=8):super().__init__()self.patch_size=patch_size self.in_channels=in_channels self.x_embedder=PatchEmbed(in_channels,patch_size,hidden_size)num_patches=(input_size//patch_size)**2self.pos_embed=nn.Parameter(torch.zeros(1,num_patches,hidden_size))self.t_embedder=TimestepEmbedder(hidden_size)self.blocks=nn.ModuleList([DiTBlock(hidden_size,num_heads)for_inrange(depth)])self.final_layer=FinalLayer(hidden_size,patch_size,in_channels)defunpatchify(self,x):# x: (B, N, P*P*C)B,N,_=x.shape P=self.patch_size C=self.in_channels H=W=int(N**0.5)x=x.view(B,H,W,P,P,C)x=x.permute(0,5,1,3,2,4).contiguous()x=x.view(B,C,H*P,W*P)returnxdefforward(self,z_t,t_embed):# 1. 图像转 patch tokenx=self.x_embedder(z_t)+self.pos_embed# 2. 时间步条件向量cond=self.t_embedder(t_embed)# 3. 一层层 DiT Blockforblockinself.blocks:x=block(x,cond)# 4. 输出每个 patch 的噪声预测,再拼回二维 latentx=self.final_layer(x,cond)x=self.unpatchify(x)returnx观察这段代码时,你重点盯 4 件事
PatchEmbed
它把二维 latent 变成 token 序列。pos_embed
它告诉模型每个 patch 在图上的空间位置。t_embedder + adaLN
它把“第几轮去噪”这个条件注入到每一个 block。unpatchify
它把 token 再还原回二维 latent 噪声预测图。
十三、下一步该学什么?
如果你已经看懂了这篇,那么下一步最值得继续深挖的是 4 个点:
扩散公式本身
彻底搞懂q ( x t ∣ x 0 ) q(x_t|x_0)q(xt∣x0)、反向采样、scheduler、DDPM/DDIM。Classifier-Free Guidance
为什么一句提示词能把图“拉向”你想要的方向。Latent Diffusion
为什么先进 VAE latent 再扩散,会极大降低计算量。Stable Diffusion / Flux / PixArt 这类模型
看它们如何把文本编码器、VAE、DiT/U-Net 拼成完整系统。
如果你愿意,我下一条可以继续按同样风格给你写:
- 《为什么扩散模型能从噪声里还原图像?》
- 或者《Classifier-Free Guidance 到底在“拉”什么?》
- 或者《DiT vs U-Net:一篇图像生成骨干网络进化史》
