当前位置: 首页 > news >正文

别再死磕公式了!用Python和PyTorch手把手复现DDPM图像去噪(附完整代码)

从零构建DDPM:Python与PyTorch实战图像去噪

在计算机视觉领域,扩散模型正迅速成为生成高质量图像的主流方法。本文将带您从零开始,使用PyTorch框架完整实现一个基础的Denoising Diffusion Probabilistic Model(DDPM),无需深入复杂的数学推导,通过代码直观理解这一强大模型的工作原理。

1. 扩散模型基础概念

扩散模型的核心思想是通过逐步添加噪声破坏图像,再学习逆向去噪过程。想象一下把一杯清水慢慢滴入墨水的过程——扩散模型的正向过程就如同这个"污染"过程,而逆向过程则是神奇的"净化"操作。

与传统GAN或VAE不同,DDPM具有几个独特优势:

  • 训练稳定性:不依赖对抗训练,避免了模式坍塌问题
  • 生成质量:逐步细化生成过程,能产生更自然的高频细节
  • 理论优雅:基于热力学的非平衡统计物理基础

在技术实现层面,DDPM主要包含两个关键阶段:

  1. 前向扩散过程(Fixed Markov Chain):逐步向数据添加高斯噪声
  2. 逆向去噪过程(Learned Transition):训练神经网络逐步去噪
# 基础配置 import torch import torch.nn as nn import numpy as np from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2. 前向扩散过程实现

前向过程定义为马尔可夫链,逐步将数据转化为各向同性高斯分布。关键在于设计合理的噪声调度(noise schedule),控制不同时间步的噪声添加量。

2.1 噪声调度设计

我们采用线性噪声调度,定义从β₁=1e-4到β_T=0.02的线性增长序列:

def linear_beta_schedule(timesteps, start=1e-4, end=0.02): return torch.linspace(start, end, timesteps) T = 1000 # 总时间步数 betas = linear_beta_schedule(T) alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) # α的连乘积 sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

2.2 单步扩散实现

给定原始图像x₀和时间步t,计算加噪后的图像x_t:

def q_sample(x_start, t, noise=None): if noise is None: noise = torch.randn_like(x_start) sqrt_alpha_cumprod_t = sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1) sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1) return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise

可视化不同时间步的加噪效果:

def plot_diffusion_process(image, num_steps=5): plt.figure(figsize=(15, 3)) plt.subplot(1, num_steps+1, 1) plt.imshow(image.squeeze(), cmap='gray') plt.title("Original") plt.axis('off') for i in range(1, num_steps+1): t = torch.tensor([i*(T//num_steps)-1]) noisy_image = q_sample(image, t) plt.subplot(1, num_steps+1, i+1) plt.imshow(noisy_image.squeeze().cpu().numpy(), cmap='gray') plt.title(f"Step {t.item()+1}") plt.axis('off') plt.show()

3. 逆向去噪模型构建

逆向过程的核心是训练一个噪声预测网络。我们采用改进的U-Net架构,包含下采样和上采样路径,并加入时间步嵌入。

3.1 时间步嵌入

将离散时间步转换为连续向量表示:

class SinusoidalPositionEmbeddings(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, t): device = t.device half_dim = self.dim // 2 embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) embeddings = t[:, None] * embeddings[None, :] embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) return embeddings

3.2 基础残差块

class Block(nn.Module): def __init__(self, in_ch, out_ch, time_emb_dim): super().__init__() self.time_mlp = nn.Linear(time_emb_dim, out_ch) self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) self.act = nn.SiLU() self.bn = nn.BatchNorm2d(out_ch) def forward(self, x, t): h = self.bn(self.act(self.conv1(x))) time_emb = self.act(self.time_mlp(t)) h = h + time_emb.reshape(-1, h.shape[1], 1, 1) return self.act(self.conv2(h))

3.3 完整U-Net实现

class UNet(nn.Module): def __init__(self, in_channels=1, out_channels=1, dim=32, dim_mults=(1, 2, 4, 8)): super().__init__() self.time_mlp = nn.Sequential( SinusoidalPositionEmbeddings(dim), nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim) ) dims = [in_channels] + [dim * m for m in dim_mults] self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) # 下采样路径 for i in range(len(dims)-1): self.downs.append(Block(dims[i], dims[i+1], dim)) # 中间层 self.mid = Block(dims[-1], dims[-1], dim) # 上采样路径 for i in reversed(range(len(dims)-1)): self.ups.append(nn.ConvTranspose2d(dims[i+1], dims[i], 4, 2, 1)) self.ups.append(Block(dims[i]*2, dims[i], dim)) self.final = nn.Conv2d(dim, out_channels, 1) def forward(self, x, t): t = self.time_mlp(t) hs = [] # 下采样 for block in self.downs: x = block(x, t) hs.append(x) x = nn.functional.avg_pool2d(x, 2) # 中间层 x = self.mid(x, t) # 上采样 for i in range(0, len(self.ups), 2): x = self.ups[i](x) skip = hs.pop() x = torch.cat([x, skip], dim=1) x = self.ups[i+1](x, t) return self.final(x)

4. 训练流程实现

DDPM的训练目标是最小化预测噪声与真实噪声之间的L2距离。

4.1 损失函数定义

def p_losses(denoise_model, x_start, t, noise=None): if noise is None: noise = torch.randn_like(x_start) x_noisy = q_sample(x_start, t, noise) predicted_noise = denoise_model(x_noisy, t) return torch.mean((noise - predicted_noise)**2)

4.2 训练循环

def train(model, dataloader, epochs=100, lr=1e-3): optimizer = torch.optim.Adam(model.parameters(), lr=lr) model.train() for epoch in range(epochs): total_loss = 0 for batch, _ in dataloader: batch = batch.to(device) # 随机采样时间步 t = torch.randint(0, T, (batch.size(0),), device=device) # 计算损失 loss = p_losses(model, batch, t) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() print(f"Epoch {epoch+1} | Loss: {total_loss/len(dataloader):.4f}") return model

5. 采样生成图像

训练完成后,我们可以通过逐步去噪从随机噪声生成新图像。

5.1 单步采样

@torch.no_grad() def p_sample(model, x, t, t_index): betas_t = betas[t].reshape(-1, 1, 1, 1) sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1) sqrt_recip_alphas_t = torch.sqrt(1.0 / alphas[t]).reshape(-1, 1, 1, 1) # 预测噪声 pred_noise = model(x, t) # 计算均值 model_mean = sqrt_recip_alphas_t * (x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t) if t_index == 0: return model_mean else: posterior_variance_t = (1 - alphas_cumprod[t-1]) / (1 - alphas_cumprod[t]) * betas[t] noise = torch.randn_like(x) return model_mean + torch.sqrt(posterior_variance_t).reshape(-1, 1, 1, 1) * noise

5.2 完整采样流程

@torch.no_grad() def p_sample_loop(model, shape): # 从随机噪声开始 img = torch.randn(shape, device=device) imgs = [] for i in reversed(range(0, T)): t = torch.full((shape[0],), i, device=device, dtype=torch.long) img = p_sample(model, img, t, i) if i % (T//10) == 0 or i == T-1: imgs.append(img.cpu()) return imgs

6. 实战演示与结果分析

让我们在MNIST数据集上训练模型并观察生成效果。

6.1 数据准备

transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

6.2 模型训练

model = UNet().to(device) trained_model = train(model, dataloader, epochs=20)

6.3 生成新图像

sample_size = 16 generated_images = p_sample_loop(trained_model, (sample_size, 1, 28, 28)) # 可视化生成过程 plt.figure(figsize=(15, 15)) for i in range(len(generated_images)): plt.subplot(1, len(generated_images), i+1) plt.imshow(generated_images[i][0].squeeze(), cmap='gray') plt.title(f"Step {i*(T//len(generated_images))}") plt.axis('off') plt.show()

通过这个完整实现,我们不仅理解了DDPM的核心原理,还获得了可以实际运行的代码。虽然我们的示例基于简单的MNIST数据集,但同样的架构经过适当调整可以扩展到更复杂的图像生成任务。

http://www.jsqmd.com/news/875884/

相关文章:

  • 腾讯点选VMP环境补全与Hook实战:构建可信浏览器沙盒
  • 如何选择性价比高的全屋定制供应商,源头全屋定制厂家攻略揭秘 - mypinpai
  • NVIDIA Profile Inspector终极指南:5步解锁显卡隐藏功能,轻松提升游戏性能30%
  • ContextMenuManager:三步彻底掌控Windows右键菜单的终极免费工具
  • 2026年目前可靠的邓州室内装修品牌哪家好 - 品牌排行榜
  • 分子动力学模拟揭秘:非晶材料断裂韧性的原子尺度起源
  • GHelper架构设计与风扇控制技术深度解析:构建华硕笔记本轻量级系统优化解决方案
  • 企业级MCP Server OAuth接入实战:租户隔离与IDP适配
  • 基于局部交叉对称色散关系的弦振幅参数化表示与数值引导
  • 性价比高的CPE流延高透膜设备先进的加工厂盘点,哪家比较靠谱 - mypinpai
  • ContextMenuManager:让Windows右键菜单从此清爽高效
  • ContextMenuManager:重新定义Windows右键菜单的交互设计思维
  • 2025-2026年产业园区公司联系电话推荐:精选资源与联系指南 - 品牌推荐
  • 广东白云学院登录接口逆向实战:DES-CBC动态密钥与高校系统反爬细节
  • 2025-2026年王雯律师电话查询:委托前请核实执业资质与收费标准 - 品牌推荐
  • Windows控制台程序逆向入门:从CMP指令看程序逻辑解构
  • 伴随方法与自动微分:高效梯度计算的核心原理与工程实践
  • Java并发工具类CountDownLatch与CyclicBarrier
  • Unity Android读取SD卡图片的5种实战方案与选型指南
  • CVE-2022-40684深度解析:飞塔防火墙session token泄露原理与实战利用
  • 保姆级教程:用perf stat排查Linux服务器性能瓶颈(附实战命令)
  • ContextMenuManager:Windows右键菜单终极管理指南,让你的电脑效率翻倍
  • 5大核心功能揭秘:BetterGI原神自动化工具完整使用指南
  • 非Root安卓设备上使用Frida Gadget实现应用层Hook
  • 2025-2026年北京老房改造装修公司推荐:五大口碑评测老房水电改造性价比高价格 - 品牌推荐
  • 文本归一化:提升朴素贝叶斯在钓鱼短信检测中的准确率
  • 量子机器学习在日志异常检测中的实践:编码、电路设计与性能评估
  • 1-4 直流电与交流电
  • 新电脑到手别急着用!Win11必做的3个存储优化设置(磁盘分区+改默认路径+软件安装避坑)
  • Hugging Face模型供应链实证分析:文档、依赖与许可证风险