DDPM 扩散模型 PyTorch 实现:10步代码解析前向与逆向过程核心
DDPM 扩散模型 PyTorch 实现:10步代码解析前向与逆向过程核心
扩散模型(Diffusion Model)近年来在图像生成领域掀起了一场革命。与GAN和VAE不同,扩散模型通过一个渐进的加噪和去噪过程来生成高质量图像。本文将带你从PyTorch实现的角度,深入理解DDPM(Denoising Diffusion Probabilistic Models)的核心机制。
1. 扩散模型基础概念
扩散模型的核心思想包含两个过程:
- 前向过程(扩散过程):逐步对图像添加高斯噪声,最终将图像完全转化为噪声
- 逆向过程(去噪过程):学习如何从噪声中逐步恢复原始图像
这两个过程都是马尔可夫链,其中每一步只依赖于前一步的状态。扩散模型的神奇之处在于,它通过学习这个逆向过程,可以从纯噪声开始生成全新的图像。
在PyTorch实现中,我们需要关注几个关键参数:
# 典型参数设置 T = 1000 # 扩散步数 beta_start = 0.0001 beta_end = 0.02 betas = torch.linspace(beta_start, beta_end, T) alphas = 1 - betas alpha_bars = torch.cumprod(alphas, dim=0)2. 前向扩散过程实现
前向过程的核心函数是q_sample,它实现了从x₀一步到位计算xₜ的功能:
def q_sample(x0, t, noise=None): """ 一步到位计算x_t :param x0: 原始图像 [batch_size, channels, height, width] :param t: 时间步 [batch_size] :param noise: 可选的外部噪声 :return: 加噪后的图像x_t """ if noise is None: noise = torch.randn_like(x0) # 计算alpha_bar_t的平方根 [batch_size, 1, 1, 1] sqrt_alpha_bar_t = extract(alpha_bars.sqrt(), t, x0.shape) # 计算1-alpha_bar_t的平方根 sqrt_one_minus_alpha_bar_t = extract((1 - alpha_bars).sqrt(), t, x0.shape) return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise这里的关键数学原理是:
x_t = √(ᾱₜ)x₀ + √(1-ᾱₜ)ε其中ᾱₜ=∏ᵢαᵢ,αᵢ=1-βᵢ
辅助函数extract用于从序列中按时间步t提取值:
def extract(arr, t, x_shape): """ 从arr中按索引t提取值,并reshape到匹配x_shape """ batch_size = t.shape[0] out = arr.gather(-1, t) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))3. 逆向去噪过程实现
逆向过程的核心是p_sample函数,它实现了从xₜ预测xₜ₋₁的一步:
def p_sample(model, x, t, t_index): """ 从x_t预测x_{t-1} :param model: 噪声预测模型 :param x: 当前图像x_t :param t: 当前时间步 :param t_index: 时间步索引 :return: x_{t-1} """ betas_t = extract(betas, t, x.shape) sqrt_one_minus_alpha_bar_t = extract((1 - alpha_bars).sqrt(), t, x.shape) sqrt_recip_alpha_t = extract(torch.sqrt(1 / alphas), t, x.shape) # 模型预测噪声 pred_noise = model(x, t) # 计算均值 model_mean = sqrt_recip_alpha_t * (x - betas_t * pred_noise / sqrt_one_minus_alpha_bar_t) if t_index == 0: return model_mean else: posterior_variance_t = extract(posterior_variance, t, x.shape) noise = torch.randn_like(x) return model_mean + torch.sqrt(posterior_variance_t) * noise逆向过程的数学原理基于:
x_{t-1} = 1/√αₜ (xₜ - βₜ/√(1-ᾱₜ)εθ(xₜ,t)) + σₜz4. 噪声预测模型架构
DDPM通常使用U-Net架构来预测噪声:
class UNet(nn.Module): def __init__(self, dim=64, dim_mults=(1, 2, 4, 8)): super().__init__() # 时间嵌入 self.time_embed = nn.Sequential( nn.Linear(64, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim * 4) ) # 下采样路径 self.down_blocks = nn.ModuleList([ ConvBlock(3, dim), DownBlock(dim, dim * 2), DownBlock(dim * 2, dim * 4), DownBlock(dim * 4, dim * 8) ]) # 中间块 self.mid_block = nn.Sequential( ResBlock(dim * 8, dim * 8), AttentionBlock(dim * 8), ResBlock(dim * 8, dim * 8) ) # 上采样路径 self.up_blocks = nn.ModuleList([ UpBlock(dim * 8, dim * 4), UpBlock(dim * 4, dim * 2), UpBlock(dim * 2, dim) ]) # 最终卷积 self.final_conv = nn.Conv2d(dim, 3, kernel_size=1) def forward(self, x, t): # 时间嵌入 t_emb = sinusoidal_embedding(t) t_emb = self.time_embed(t_emb) # 下采样 h = [] for block in self.down_blocks: x = block(x, t_emb) h.append(x) x = F.avg_pool2d(x, 2) # 中间块 x = self.mid_block(x, t_emb) # 上采样 for block in self.up_blocks: x = F.interpolate(x, scale_factor=2, mode='nearest') x = torch.cat([x, h.pop()], dim=1) x = block(x, t_emb) return self.final_conv(x)5. 训练过程实现
DDPM的训练目标是最小化预测噪声和实际噪声的均方误差:
def train(model, dataloader, optimizer, device, epochs): model.train() for epoch in range(epochs): for batch, _ in dataloader: batch = batch.to(device) # 随机采样时间步 t = torch.randint(0, T, (batch.size(0),), device=device) # 生成噪声 noise = torch.randn_like(batch) # 前向过程加噪 noisy_images = q_sample(batch, t, noise) # 预测噪声 pred_noise = model(noisy_images, t) # 计算损失 loss = F.mse_loss(pred_noise, noise) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step()6. 图像生成过程
训练完成后,我们可以从纯噪声开始逐步生成图像:
@torch.no_grad() def p_sample_loop(model, shape, device): # 从纯噪声开始 img = torch.randn(shape, device=device) for i in reversed(range(T)): t = torch.full((shape[0],), i, device=device, dtype=torch.long) img = p_sample(model, img, t, i) return img def generate(model, n_samples=16, device='cuda'): # 生成样本 samples = p_sample_loop( model, (n_samples, 3, 32, 32), # 假设生成32x32图像 device ) return samples7. 关键数学推导简化
理解DDPM需要掌握几个核心数学概念:
前向过程分布:
q(x_t|x_0) = N(x_t; √(ᾱₜ)x_0, (1-ᾱₜ)I)逆向过程分布:
p_θ(x_{t-1}|x_t) = N(x_{t-1}; μ_θ(x_t,t), Σ_θ(x_t,t))损失函数(简化形式):
L = E_{t,x_0,ε}[||ε - ε_θ(x_t,t)||^2]
8. 实际应用技巧
在实现DDPM时,有几个实用技巧:
- 噪声调度:βₜ的选择对结果影响很大,通常使用线性或余弦调度
- 时间步嵌入:使用正弦位置编码将时间步t嵌入到高维空间
- 梯度裁剪:训练时对梯度进行裁剪可以稳定训练过程
# 余弦调度示例 def cosine_beta_schedule(timesteps, s=0.008): steps = timesteps + 1 x = torch.linspace(0, timesteps, steps) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999)9. 性能优化策略
为了提高DDPM的效率和生成质量,可以考虑以下策略:
- 重要性采样:根据时间步的重要性调整采样频率
- 加速采样:减少采样步数而不显著降低质量
- 混合精度训练:使用FP16加速训练过程
# 混合精度训练示例 scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): pred_noise = model(noisy_images, t) loss = F.mse_loss(pred_noise, noise) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()10. 完整代码结构
一个完整的DDPM实现通常包含以下文件结构:
ddpm/ ├── model.py # U-Net模型定义 ├── diffusion.py # 前向和逆向过程实现 ├── train.py # 训练脚本 ├── generate.py # 生成脚本 └── utils.py # 辅助函数扩散模型代表了生成模型的一个重要方向,通过理解这些核心代码,你可以更好地掌握其工作原理,并在此基础上进行改进和创新。
