当前位置: 首页 > news >正文

别再被数学劝退!用PyTorch从零实现DDPM扩散模型(附完整代码)

用PyTorch实战DDPM:无需深究数学也能玩转扩散模型

当你在社交媒体上看到AI生成的艺术作品时,是否好奇过它们背后的技术原理?扩散模型(Diffusion Models)作为当前最热门的生成式AI技术之一,正以惊人的速度改变着内容创作的格局。本文将带你绕过复杂的数学推导,直接进入代码实践环节,用PyTorch从零构建一个完整的DDPM(Denoising Diffusion Probabilistic Models)模型。

1. 环境准备与数据加载

在开始之前,我们需要配置好开发环境。推荐使用Python 3.8+和PyTorch 1.12+版本:

pip install torch torchvision matplotlib tqdm

对于数据集,我们将使用经典的CIFAR-10,它包含60,000张32x32的彩色图像:

import torch from torchvision import datasets, transforms transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) dataset = datasets.CIFAR10( root='./data', train=True, download=True, transform=transform ) dataloader = torch.utils.data.DataLoader( dataset, batch_size=128, shuffle=True )

提示:如果你的GPU显存较小,可以将batch_size调整为64或32

2. DDPM核心组件实现

2.1 噪声调度器

扩散模型的核心在于如何合理地添加和去除噪声。我们需要定义一个噪声调度器来控制不同时间步的噪声强度:

import math def linear_beta_schedule(timesteps): beta_start = 0.0001 beta_end = 0.02 return torch.linspace(beta_start, beta_end, timesteps) def cosine_beta_schedule(timesteps, s=0.008): steps = timesteps + 1 x = torch.linspace(0, timesteps, steps) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) timesteps = 1000 betas = cosine_beta_schedule(timesteps) # 使用余弦调度器效果更好 # 预计算有用的值 alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

2.2 前向加噪过程

前向过程逐步将数据转换为高斯噪声,这个过程是固定的,不需要训练:

def q_sample(x_start, t, noise=None): if noise is None: noise = torch.randn_like(x_start) sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape) sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise def extract(a, t, x_shape): batch_size = t.shape[0] out = a.gather(-1, t.cpu()) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

3. 构建U-Net模型

U-Net是DDPM中用于预测噪声的核心网络结构。下面我们实现一个简化版的U-Net:

import torch.nn as nn import torch.nn.functional as F class Block(nn.Module): def __init__(self, in_ch, out_ch, time_emb_dim): super().__init__() self.time_mlp = nn.Linear(time_emb_dim, out_ch) self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) self.norm = nn.GroupNorm(8, out_ch) def forward(self, x, t): h = self.conv1(x) h = self.norm(h) h = F.silu(h) time_emb = F.silu(self.time_mlp(t)) h = h + time_emb[:, :, None, None] h = self.conv2(h) h = self.norm(h) h = F.silu(h) return h class UNet(nn.Module): def __init__(self): super().__init__() self.time_mlp = nn.Sequential( SinusoidalPositionEmbeddings(100), nn.Linear(100, 256), nn.SiLU(), nn.Linear(256, 256) ) self.down1 = Block(3, 64, 256) self.down2 = Block(64, 128, 256) self.down3 = Block(128, 256, 256) self.mid = Block(256, 256, 256) self.up1 = Block(512, 128, 256) self.up2 = Block(256, 64, 256) self.up3 = Block(128, 64, 256) self.out = nn.Conv2d(64, 3, 1) def forward(self, x, t): t = self.time_mlp(t) # 下采样 h1 = self.down1(x, t) h2 = self.down2(F.max_pool2d(h1, 2), t) h3 = self.down3(F.max_pool2d(h2, 2), t) # 中间层 h = self.mid(F.max_pool2d(h3, 2), t) # 上采样 h = F.interpolate(h, scale_factor=2, mode='nearest') h = self.up1(torch.cat([h, h3], dim=1), t) h = F.interpolate(h, scale_factor=2, mode='nearest') h = self.up2(torch.cat([h, h2], dim=1), t) h = F.interpolate(h, scale_factor=2, mode='nearest') h = self.up3(torch.cat([h, h1], dim=1), t) return self.out(h) class SinusoidalPositionEmbeddings(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, time): device = time.device half_dim = self.dim // 2 embeddings = math.log(10000) / (half_dim - 1) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) embeddings = time[:, None] * embeddings[None, :] embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) return embeddings

4. 训练与采样

4.1 训练循环

DDPM的训练目标是让网络学会预测添加到图像中的噪声:

def train(model, dataloader, optimizer, epochs, device): model.train() for epoch in range(epochs): for step, (images, _) in enumerate(dataloader): images = images.to(device) # 随机采样时间步 t = torch.randint(0, timesteps, (images.shape[0],), device=device).long() # 生成随机噪声 noise = torch.randn_like(images) # 前向加噪过程 noisy_images = q_sample(images, t, noise) # 预测噪声 predicted_noise = model(noisy_images, t) # 计算损失 loss = F.mse_loss(noise, predicted_noise) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() if step % 100 == 0: print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")

4.2 采样生成

训练完成后,我们可以使用模型从纯噪声开始逐步去噪生成新图像:

@torch.no_grad() def p_sample(model, x, t, t_index): betas_t = extract(betas, t, x.shape) sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape) sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape) # 计算模型均值 model_mean = sqrt_recip_alphas_t * ( x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t ) if t_index == 0: return model_mean else: posterior_variance_t = extract(posterior_variance, t, x.shape) noise = torch.randn_like(x) return model_mean + torch.sqrt(posterior_variance_t) * noise @torch.no_grad() def sample(model, image_size, batch_size=16, channels=3): device = next(model.parameters()).device # 从随机噪声开始 img = torch.randn((batch_size, channels, image_size, image_size), device=device) for i in reversed(range(0, timesteps)): t = torch.full((batch_size,), i, device=device, dtype=torch.long) img = p_sample(model, img, t, i) # 将图像从[-1,1]转换到[0,1] img = (img + 1) * 0.5 return img

5. 模型优化与技巧

在实际应用中,我们可以采用以下几种策略来提升DDPM的性能:

  1. 学习率调度:使用余弦退火学习率可以显著提升模型收敛速度
  2. 混合精度训练:通过FP16训练可以节省显存并加快训练速度
  3. EMA模型:使用指数移动平均的模型参数可以提高生成质量
  4. 渐进式训练:从低分辨率开始训练,逐步提高分辨率
# 示例:EMA模型实现 class EMA: def __init__(self, beta): super().__init__() self.beta = beta self.step = 0 def update_model_average(self, ema_model, current_model): for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()): old_weight, new_weight = ema_params.data, current_params.data ema_params.data = self.update_average(old_weight, new_weight) def update_average(self, old, new): if old is None: return new return old * self.beta + (1 - self.beta) * new

在CIFAR-10数据集上训练约50个epoch后,你应该能够看到模型开始生成可识别的物体图像。虽然32x32的分辨率不高,但这个完整的实现已经包含了DDPM的所有关键组件。

http://www.jsqmd.com/news/849350/

相关文章:

  • 通过环境变量为hermesagent配置taotoken作为自定义模型服务提供方
  • CANN/asc-devkit 设置梯度输出类型
  • CANNBot torch-compile 快速入门
  • 2026河北钢制防火门多少钱一平米?甲乙丙级最新报价
  • CANN混元视频配置说明
  • 数据中心工频UPS哪家好?2026工频不间断电源/核磁用UPS电源生产厂家权威推荐 - 栗子测评
  • CTF中的音频隐写术实战:从‘兔耳’和‘调频收音机’两道Misc题,学会用Python脚本提取隐藏信息
  • HermesAgent工具连接Taotoken自定义模型提供方的完整流程
  • CANN Bench交叉熵损失算子评测
  • Matlab阶跃响应性能指标自动化计算:从原理到工程实践
  • 如何快速上手elec-ops-inspection:昇腾平台部署指南
  • Configor 自动重载功能深度解析:实现配置热更新的终极指南
  • CANN/hccl RDMA QP端口配置路径
  • 轨距调整片定制哪家好?2026年绝缘轨距块生产厂家优质供应商推荐指南:新建铁路配件领衔 - 栗子测评
  • 2026机房不间断电源生产厂家哪家好?深圳不间断电源生产厂家实力深度解析 - 栗子测评
  • cann/asc-devkit SetGradOutput接口
  • CANN ops-fft部署指南:生产环境中的配置、监控与故障排除
  • npc_gzip异常处理与调试手册:解决压缩器错误的10个实用技巧
  • Commit Mono版本管理指南:如何优雅地升级和回滚字体版本
  • 源头工厂直供:利成充气水池定制厂家,广东便携式宠物泳池、PVC 戏水玩具、水上充气浮排专业生产基地 - 栗子测评
  • 穿透算法黑箱:2026论文降AI率工具深度测评,早标网语义保真度99%
  • 橡胶垫板定制厂家推荐:新建铁路配件领衔,2026年口碑好的调高垫板批发厂家/轨道橡胶垫板生产厂家/精调件生产厂家盘点 - 栗子测评
  • Transformer架构解析:自注意力机制与LLM核心技术
  • CrossGeo:首个跨卫星-无人机-地面三重视角的6-DoF 3D重建与定位数据集详解
  • 【YOLO目标检测全栈实战】48 深入TensorRT加速:从28ms到6ms的C++推理实战
  • Seed-VC语音克隆指南:5分钟实现零样本实时语音转换的终极方案
  • ARM SPE Profiling Buffer机制与性能分析实践
  • 地空协同巡检新范式:elec-ops-inspection 3D空间建模技术
  • GIFT应用案例:从Web服务到移动应用的实际部署方案
  • USB/IP Windows:打破物理限制的USB设备网络共享终极方案