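Stop Wrestling with Formulas! Reproduce DDPM Image Denoising Step by Step with Python and PyTorch (Full Code Included)
Building DDPM from Scratch: Hands-On Image Denoising with Python and PyTorch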
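In computer vision, diffusion models are quickly becoming a mainstream approach to generating high-quality images. This article walks you through implementing a basic Denoising Diffusion Probabilistic Model (DDPM) from scratch in PyTorch. Instead of working through the full mathematical derivation, we build intuition for how the model works directly from the code.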
1. Diffusion Model Basics
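The core idea of a diffusion model is to corrupt an image by adding noise a little at a time, and then learn the reverse process that removes that noise. Picture slowly dripping ink into a glass of clear water: the forward process of a diffusion model is this "pollution" step, and the reverse process is the almost magical "purification" that turns the inky water clear again.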
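Compared with traditional GANs and VAEs, DDPM has several distinctive advantages:
- Training stability: it does not rely on adversarial training, so it sidesteps the mode collapse that GANs are prone to
- Sample quality: the image is refined step by step, which tends to produce more natural high-frequency detail
- Theoretical elegance: it is grounded in non-equilibrium thermodynamics from statistical physics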
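At the implementation level, DDPM consists of two key stages:
- Forward diffusion process (fixed Markov chain): gradually adds Gaussian noise to the data
- Reverse denoising process (learned transitions): trains a neural network to remove the noise step by step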
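In the standard DDPM formulation, the two stages correspond to the following Gaussian transition kernels. Here β_t is the per-step noise variance set by the noise schedule, and the mean μ_θ of the reverse step is derived from the network's noise prediction:

$$
q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t\mathbf{I}\right)
$$

$$
p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2\mathbf{I}\right)
$$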
```python
# Basic setup
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
2. Implementing the Forward Diffusion Process
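The forward process is defined as a Markov chain that gradually turns the data into an isotropic Gaussian distribution. The key is a well-designed noise schedule, which controls how much noise is added at each time step.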
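To see what a schedule actually controls, the standalone sketch below rebuilds a linear schedule with the same endpoints used in 2.1 and reports how much of the original signal survives at a few time steps. The chosen steps and the print format are illustrative only. The signal coefficient is sqrt(ᾱ_t) and the noise coefficient is sqrt(1 − ᾱ_t).

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)                # linear schedule beta_1 .. beta_T
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)   # alpha_bar_t = prod_s (1 - beta_s)

# Report the signal and noise coefficients at a few (0-indexed) time steps
for t in [0, 249, 499, 749, 999]:
    signal = alphas_cumprod[t].sqrt().item()
    noise = (1.0 - alphas_cumprod[t]).sqrt().item()
    print(f"t={t + 1:4d}  signal={signal:.4f}  noise={noise:.4f}")
```

By the last step ᾱ_T is tiny, so x_T is almost pure Gaussian noise. This is what allows sampling to start from `torch.randn`.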
2.1 Designing the Noise Schedule
We use a linear noise schedule that increases from β₁ = 1e-4 to β_T = 0.02:
```python
def linear_beta_schedule(timesteps, start=1e-4, end=0.02):
    # beta_t grows linearly from `start` to `end`
    return torch.linspace(start, end, timesteps)

T = 1000  # total number of diffusion steps
betas = linear_beta_schedule(T)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # cumulative product alpha_bar_t
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
```
2.2 Single-Step Diffusion
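Each step adds independent Gaussian noise, so x_t does not have to be produced by iterating the chain. It can be sampled from x_0 in one shot. Here α_t = 1 − β_t, and ᾱ_t is the cumulative product stored in `alphas_cumprod`:

$$
q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\mathbf{I}\right),
\qquad
\bar\alpha_t = \prod_{s=1}^{t} (1-\beta_s)
$$

$$
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})
$$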
Given an original image x₀ and a time step t, compute the noised image x_t:
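```python
def q_sample(x_start, t, noise=None):
    # Sample x_t ~ q(x_t | x_0) in closed form
    if noise is None:
        noise = torch.randn_like(x_start)
    # The schedule tensors live on the CPU: index them with a CPU copy of t,
    # then move the per-sample coefficients to the data's device
    t_cpu = t.cpu()
    sqrt_alpha_cumprod_t = sqrt_alphas_cumprod[t_cpu].to(x_start.device).reshape(-1, 1, 1, 1)
    sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alphas_cumprod[t_cpu].to(x_start.device).reshape(-1, 1, 1, 1)
    return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise
```
Visualize the noising effect at different time steps: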
```python
def plot_diffusion_process(image, num_steps=5):
    # image: one normalized image tensor, e.g. shape (1, 28, 28)
    plt.figure(figsize=(15, 3))
    plt.subplot(1, num_steps + 1, 1)
    plt.imshow(image.squeeze().cpu().numpy(), cmap='gray')
    plt.title("Original")
    plt.axis('off')
    for i in range(1, num_steps + 1):
        # Evenly spaced (0-indexed) time steps ending at T - 1
        t = torch.tensor([i * (T // num_steps) - 1])
        noisy_image = q_sample(image, t)
        plt.subplot(1, num_steps + 1, i + 1)
        plt.imshow(noisy_image.squeeze().cpu().numpy(), cmap='gray')
        plt.title(f"Step {t.item() + 1}")
        plt.axis('off')
    plt.show()
```
3. Building the Reverse Denoising Model
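The heart of the reverse process is training a network that predicts the noise in x_t. We use a compact U-Net with a downsampling path, an upsampling path with skip connections, and a time-step embedding that tells the network how noisy its input is.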
3.1 Time-Step Embedding
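The embedding computed by `SinusoidalPositionEmbeddings` follows the Transformer-style sinusoidal encoding. For embedding size d and h = d/2, the code computes:

$$
\omega_k = 10000^{-k/(h-1)},\qquad k = 0,\dots,h-1
$$

$$
\mathrm{emb}(t) = \left[\sin(t\omega_0),\dots,\sin(t\omega_{h-1}),\ \cos(t\omega_0),\dots,\cos(t\omega_{h-1})\right]
$$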
Map each discrete time step to a continuous vector representation:
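```python
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of integer time steps, shape (B,) -> (B, dim)."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        # Geometric sequence of frequencies from 1 down to 1/10000
        embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # (B, 1) * (1, half_dim) -> (B, half_dim)
        embeddings = t[:, None] * embeddings[None, :]
        # Concatenate the sin and cos halves -> (B, dim)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```
3.2 Basic Convolutional Block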
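The block injects the time information as a per-channel bias. The projected embedding has shape (B, C). It is reshaped to (B, C, 1, 1) so that PyTorch broadcasting adds the same value at every spatial position of that channel. The tiny standalone check below uses made-up shapes and values purely for illustration:

```python
import torch

B, C, H, W = 2, 3, 4, 4
h = torch.zeros(B, C, H, W)  # stand-in feature map
# Stand-in for the projected time embedding: one scalar per (sample, channel)
time_emb = torch.arange(B * C, dtype=torch.float32).reshape(B, C)

# (B, C, 1, 1) broadcasts over H and W
out = h + time_emb.reshape(-1, C, 1, 1)
print(out[1, 2])  # every pixel of sample 1, channel 2 equals time_emb[1, 2]
```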
```python
class Block(nn.Module):
    """Two 3x3 convolutions with the time embedding added in between (no residual shortcut)."""
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)  # project time embedding to out_ch
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.act = nn.SiLU()
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x, t):
        h = self.bn(self.act(self.conv1(x)))
        # Add the time embedding as a per-channel bias, broadcast over H and W
        time_emb = self.act(self.time_mlp(t))
        h = h + time_emb.reshape(-1, h.shape[1], 1, 1)
        return self.act(self.conv2(h))
```
3.3 Full U-Net Implementation
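Before building the U-Net, it helps to trace feature-map sizes. Each level of the down path stores its output as a skip connection and then halves the size with `avg_pool2d`, which rounds odd sizes down. On the up path, `ConvTranspose2d(kernel_size=4, stride=2, padding=1)` exactly doubles the size. The hypothetical helper below is not part of the model. It shows that on 28×28 MNIST with four levels, two upsampled maps come out smaller than their skips (2 vs 3, 6 vs 7). Those maps therefore need resizing before concatenation, or the input could be padded to 32×32.

```python
def unet_spatial_sizes(size, num_levels):
    """Hypothetical helper: trace the side length of square feature maps.

    Returns the skip sizes stored on the way down and, for each up level,
    a pair (size after ConvTranspose2d, size of the matching skip).
    """
    skips = []
    for _ in range(num_levels):
        skips.append(size)  # a Block keeps the size; its output is stored as a skip
        size //= 2          # avg_pool2d(x, 2) rounds odd sizes down
    pairs = []
    for skip in reversed(skips):
        up = size * 2       # ConvTranspose2d(k=4, s=2, p=1) doubles the size
        pairs.append((up, skip))
        size = skip         # assume the upsampled map is resized to match the skip
    return skips, pairs

skips, pairs = unet_spatial_sizes(28, num_levels=4)
print(skips)  # [28, 14, 7, 3]
print(pairs)  # [(2, 3), (6, 7), (14, 14), (28, 28)]
```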
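```python
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, dim=32, dim_mults=(1, 2, 4, 8)):
        super().__init__()
        # Time embedding: sinusoidal encoding followed by a small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )
        # Lift the input to `dim` channels so the last up Block also ends at `dim`
        self.init_conv = nn.Conv2d(in_channels, dim, 3, padding=1)
        dims = [dim] + [dim * m for m in dim_mults]  # e.g. [32, 32, 64, 128, 256]
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        # Downsampling path
        for i in range(len(dims) - 1):
            self.downs.append(Block(dims[i], dims[i + 1], dim))
        # Bottleneck
        self.mid = Block(dims[-1], dims[-1], dim)
        # Upsampling path: upsample, concatenate the matching skip, fuse with a Block
        for i in reversed(range(len(dims) - 1)):
            self.ups.append(nn.ConvTranspose2d(dims[i + 1], dims[i + 1], 4, 2, 1))
            self.ups.append(Block(dims[i + 1] * 2, dims[i], dim))
        self.final = nn.Conv2d(dim, out_channels, 1)

    def forward(self, x, t):
        t = self.time_mlp(t)
        x = self.init_conv(x)
        hs = []
        # Downsample
        for block in self.downs:
            x = block(x, t)
            hs.append(x)
            x = nn.functional.avg_pool2d(x, 2)
        # Bottleneck
        x = self.mid(x, t)
        # Upsample
        for i in range(0, len(self.ups), 2):
            x = self.ups[i](x)
            skip = hs.pop()
            # Pooling rounds odd sizes down (28 -> 14 -> 7 -> 3), so match the skip's size
            if x.shape[-2:] != skip.shape[-2:]:
                x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="nearest")
            x = torch.cat([x, skip], dim=1)
            x = self.ups[i + 1](x, t)
        return self.final(x)
```
4. Implementing the Training Pipeline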
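DDPM is trained to minimize the squared L2 distance (mean squared error) between the predicted noise and the noise that was actually added.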
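Written out, this is the simplified objective from the DDPM paper. Here t is drawn uniformly over the time steps (0-indexed in the code), and ε is the noise passed to `q_sample`. `p_losses` computes a Monte Carlo estimate of it, averaged over pixels as well as over the batch:

$$
L_{\text{simple}}(\theta) = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\|\epsilon - \epsilon_\theta\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right)\right\|^2\right]
$$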
4.1 Defining the Loss Function
```python
def p_losses(denoise_model, x_start, t, noise=None):
    # Noise the clean batch to step t, then regress the added noise with MSE
    if noise is None:
        noise = torch.randn_like(x_start)
    x_noisy = q_sample(x_start, t, noise)
    predicted_noise = denoise_model(x_noisy, t)
    return torch.mean((noise - predicted_noise) ** 2)
```
4.2 Training Loop
```python
def train(model, dataloader, epochs=100, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch, _ in dataloader:  # class labels are not used
            batch = batch.to(device)
            # Sample a random time step for each image in the batch
            t = torch.randint(0, T, (batch.size(0),), device=device)
            # Compute the loss
            loss = p_losses(model, batch, t)
            # Backpropagate and update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} | Loss: {total_loss/len(dataloader):.4f}")
    return model
```
5. Sampling New Images
Once training is done, we can generate new images by starting from pure random noise and denoising it step by step.
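Each reverse step applies the following update, which is what `p_sample` computes. The noise term is dropped at the final step:

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z,\qquad z \sim \mathcal{N}(0,\mathbf{I})
$$

$$
\sigma_t^2 = \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t
$$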
5.1 A Single Sampling Step
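```python
@torch.no_grad()
def p_sample(model, x, t, t_index):
    # The schedule tensors live on the CPU: index with a CPU copy of t, then move to x's device
    t_cpu = t.cpu()
    betas_t = betas[t_cpu].to(x.device).reshape(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t_cpu].to(x.device).reshape(-1, 1, 1, 1)
    sqrt_recip_alphas_t = torch.sqrt(1.0 / alphas[t_cpu]).to(x.device).reshape(-1, 1, 1, 1)
    # Predict the noise
    pred_noise = model(x, t)
    # Posterior mean
    model_mean = sqrt_recip_alphas_t * (x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean  # no noise at the final step
    # Posterior variance (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
    posterior_variance_t = (1 - alphas_cumprod[t_cpu - 1]) / (1 - alphas_cumprod[t_cpu]) * betas[t_cpu]
    noise = torch.randn_like(x)
    return model_mean + torch.sqrt(posterior_variance_t).to(x.device).reshape(-1, 1, 1, 1) * noise
```
5.2 The Full Sampling Loop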
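```python
@torch.no_grad()
def p_sample_loop(model, shape):
    # Use BatchNorm running statistics instead of per-batch statistics while sampling
    model.eval()
    # Start from pure Gaussian noise
    img = torch.randn(shape, device=device)
    imgs = []
    for i in reversed(range(0, T)):
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        img = p_sample(model, img, t, i)
        # Snapshot after the first step and then every T // 10 steps
        if i % (T // 10) == 0 or i == T - 1:
            imgs.append(img.cpu())
    return imgs
```
6. Hands-On Demo and Results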
Let's train the model on the MNIST dataset and look at what it generates.
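The data pipeline uses `Normalize((0.5,), (0.5,))`, which maps pixel values from [0, 1] to [-1, 1], the range the model learns to generate. To display samples with a fixed intensity range, you can undo that mapping. `to_unit_range` is a hypothetical helper and is not part of the original code:

```python
import torch

def to_unit_range(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: invert Normalize((0.5,), (0.5,)), mapping [-1, 1] back to [0, 1]."""
    # Generated samples can overshoot [-1, 1] slightly, so clamp after rescaling
    return ((x + 1) / 2).clamp(0.0, 1.0)

# Normalize computes (x - mean) / std per channel; with mean = std = 0.5 it maps [0, 1] -> [-1, 1]
pixels = torch.tensor([0.0, 0.25, 1.0])
normalized = (pixels - 0.5) / 0.5
restored = to_unit_range(normalized)
```

Passing `to_unit_range(img)` together with `vmin=0, vmax=1` to `plt.imshow` keeps brightness comparable across snapshots. Without that, each panel is auto-scaled separately.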
6.1 Data Preparation
```python
transform = transforms.Compose([
    transforms.ToTensor(),                # pixel values in [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # rescale to [-1, 1]
])
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
```
6.2 Training the Model
```python
model = UNet().to(device)
trained_model = train(model, dataloader, epochs=20)
```
6.3 Generating New Images
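```python
sample_size = 16
generated_images = p_sample_loop(trained_model, (sample_size, 1, 28, 28))

# Visualize the denoising trajectory of the first sample
n = len(generated_images)
plt.figure(figsize=(2 * n, 2))
for i in range(n):
    plt.subplot(1, n, i + 1)
    plt.imshow(generated_images[i][0].squeeze(), cmap='gray')
    # Snapshots were taken at t = T - 1, then every T // 10 steps down to t = 0
    step = T - 1 if i == 0 else T - i * (T // 10)
    plt.title(f"t={step}")
    plt.axis('off')
plt.show()
```
With this complete implementation, we have built an intuition for DDPM's core ideas and also ended up with code that actually runs. The example uses the simple MNIST dataset, but the same architecture can be extended to more complex image-generation tasks with appropriate adjustments, such as changing `in_channels` and `out_channels` for color images.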
