当前位置: 首页 > news >正文

[DL_Net从入门到入土] 变分自编码器 VAE

[DL_Net从入门到入土] 变分自编码器 VAE

📢个人导航

知乎:https://www.zhihu.com/people/byzh_rc

CSDN:https://blog.csdn.net/qq_54636039

注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码

参考文章:各方资料

文章目录

  • [DL_Net从入门到入土] 变分自编码器 VAE
    • 📢个人导航
    • 📖参考资料
    • 🌱背景
    • ⚙️架构(公式)
        • 1. Encoder
        • 2. 隐变量z
        • 3. Decoder
        • 4. 重参数化技巧
        • 5. 损失函数
    • 👍优点/创新点
        • 1. 隐空间更加连续
        • 2. 可以生成新样本
        • 3. 引入概率建模思想
    • 👎缺点
        • 1. 生成图像可能比较模糊
        • 2. KL 散度可能导致重建质量下降
        • 3. 可能出现 posterior collapse
    • 💻代码实现

📖参考资料

Auto-Encoding Variational Bayes.

🌱背景

VAE: Variational AutoEncoder: 变分自编码器

传统 AE 有一个问题: 没有规定隐变量z zz应该长什么样, 不要求隐空间具有良好的结构

假设 AE 把三张图片编码成下面几个点:

  • 数字 1 → z = [1.0, 1.0]
  • 数字 7 → z = [5.0, 5.0]
  • 数字 9 → z = [20.0, 20.0]

这些点本身可以被 AE 的 Decoder 还原出来

但是问题在于:如果我们随便采样一个点,比如z = [ 10.0 , 10.0 ] z=[10.0,10.0]z=[10.0,10.0],Decoder 不能生成一个合理的数字
因为 AE 训练时只见过 Encoder 生成出来的那些z zz, 故而没有要求整个隐空间都连续、平滑、有意义

VAE 想让隐空间变得更加规整

普通 AE 是把输入编码成一个确定的点:x → z x→zxz
VAE 会把输入编码成一个概率分布:x → μ , σ x→μ,σxμ,σ

约束这些分布都接近标准正态分布N ( 0 , I ) \mathcal{N}(0, I)N(0,I), 那么隐空间就会变得更加连续、规整

对于一个输入x xx,Encoder 不再输出一个固定的z zz,而是输出一个分布的参数:

  • μ \muμ:均值
  • σ \sigmaσ:标准差

然后再从这个分布中采样一个z zz:
z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2)zN(μ,σ2)
最后用 Decoder 把z zz还原成x ^ \hat{x}x^:
x ^ = g ϕ ( z ) \hat{x} = g_\phi(z)x^=gϕ(z)

⚙️架构(公式)

输入 x ↓ Encoder ↓ μ, logσ² !!! ↓ 重参数化采样 !!! ↓ 隐变量 z !!! ↓ Decoder ↓ 重建结果 x'
1. Encoder

AE 的 Encoder 是:
z = f θ ( x ) z = f_\theta(x)z=fθ(x)
VAE 的 Encoder是:
μ , log ⁡ σ 2 = f θ ( x ) \mu, \log \sigma^2 = f_\theta(x)μ,logσ2=fθ(x)

2. 隐变量z

VAE 假设隐变量z zz服从一个高斯分布:
z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2)zN(μ,σ2)
-> 对于同一个输入x xx, VAE 不认为它只能对应一个固定的z zz, 而是认为它可以对应一片附近区域

3. Decoder

Decoder 的作用和 AE 类似:
x ^ = g ϕ ( z ) \hat{x} = g_\phi(z)x^=gϕ(z)
也就是把隐变量z zz解码成重建结果x ^ \hat{x}x^

4. 重参数化技巧

采样操作本身是随机的,普通采样无法直接反向传播

z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2)zN(μ,σ2), 这个采样过程会导致梯度不好传回 Encoder

原本是:
z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2)zN(μ,σ2)
改写成:
z = μ + σ ⊙ ϵ z = \mu + \sigma \odot \epsilonz=μ+σϵ

  • ϵ \epsilonϵ:从标准正态分布中随机采样的噪声ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I)ϵN(0,I)
  • μ \muμ:Encoder 输出的均值
  • σ \sigmaσ:Encoder 输出的标准差
  • ⊙ \odot:逐元素相乘

-> 随机性被转移到了ϵ \epsilonϵ上, 那么μ \muμσ \sigmaσ就可以参与梯度计算了

5. 损失函数

由两部分组成:
L V A E = L r e c o n + L K L \mathcal{L}_{VAE} = \mathcal{L}_{recon} + \mathcal{L}_{KL}LVAE=Lrecon+LKL

目标作用损失
重建好x ^ \hat{x}x^尽可能接近x xxL r e c o n \mathcal{L}_{recon}Lrecon
隐空间规整z zz的分布接近标准正态分布L K L \mathcal{L}_{KL}LKL

L r e c o n = M S E ( x , x ^ ) ∣ B C E ( x , x ^ ) \mathcal{L}_{recon}=MSE(x,\hat{x}) | BCE(x,\hat{x})Lrecon=MSE(x,x^)BCE(x,x^)

像 MNIST 这类像素范围在[ 0 , 1 ] [0,1][0,1]的图片,常用 BCE

L K L = − 1 2 ∑ ( 1 + log ⁡ σ 2 − μ 2 − σ 2 ) \mathcal{L}_{KL}=-\frac{1}{2}\sum\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)LKL=21(1+logσ2μ2σ2)

👍优点/创新点

1. 隐空间更加连续

这样在两个样本之间插值时,VAE 更容易得到合理的过渡结果

2. 可以生成新样本

训练完成后,我们可以直接随机采样一个z zz
z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I)zN(0,I)
然后送入 Decoder,x ^ = g ϕ ( z ) \hat{x} = g_\phi(z)x^=gϕ(z), 从而生成新的数据

这就是 VAE 能够作为生成模型的关键原因

3. 引入概率建模思想

AE 是确定性模型:
x → z → x ^ x \rightarrow z \rightarrow \hat{x}xzx^
VAE 是概率生成模型:
x → q θ ( z ∣ x ) → z → p ϕ ( x ∣ z ) x \rightarrow q_\theta(z|x) \rightarrow z \rightarrow p_\phi(x|z)xqθ(zx)zpϕ(xz)
-> VAE 不只是学习一个编码结果,而是学习一个潜在变量的概率分布

👎缺点

1. 生成图像可能比较模糊

VAE 经常使用 MSE 或 BCE 作为重建损失, 这些损失函数倾向于让模型生成“平均化”的结果

对于图像来说,就可能导致生成图片不够锐利,看起来比较模糊

这也是为什么很多情况下,GAN 生成的图片会比 VAE 更清晰

如果只看生成图片的质量,传统 VAE 往往不如 GAN 或 Diffusion Model

2. KL 散度可能导致重建质量下降

VAE 不仅要重建输入,还要让隐变量分布接近标准正态分布
-. 这两个目标有时会互相冲突

如果 KL 约束太强,模型为了让z zz更接近标准正态分布,可能会牺牲一部分重建效果

3. 可能出现 posterior collapse

posterior collapse: Decoder 太强了,导致它几乎不依赖隐变量z zz

模型虽然形式上有 Encoder 和隐变量z zz,但 Decoder 实际上可能忽略了z zz
-> 此时隐变量没有学到有用信息

💻代码实现

importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassVAE(nn.Module):def__init__(self,input_dim=784,hidden_dim=128,latent_dim=32):super(VAE,self).__init__()# Encoder: x -> hiddenself.encoder=nn.Sequential(nn.Linear(input_dim,hidden_dim),nn.ReLU())# 分别输出 mu 和 logvarself.fc_mu=nn.Linear(hidden_dim,latent_dim)self.fc_logvar=nn.Linear(hidden_dim,latent_dim)# Decoder: z -> x_hatself.decoder=nn.Sequential(nn.Linear(latent_dim,hidden_dim),nn.ReLU(),nn.Linear(hidden_dim,input_dim),nn.Sigmoid())def_encode(self,x):h=self.encoder(x)mu=self.fc_mu(h)logvar=self.fc_logvar(h)returnmu,logvardef_reparameterize(self,mu,logvar):# logvar = log(sigma^2)# std = sigmastd=torch.exp(0.5*logvar)# eps ~ N(0, I)eps=torch.randn_like(std)# z = mu + sigma * epsz=mu+std*epsreturnzdef_decode(self,z):x_hat=self.decoder(z)returnx_hatdefforward(self,x):mu,logvar=self._encode(x)z=self._reparameterize(mu,logvar)x_hat=self._decode(z)returnx_hat,mu,logvar

http://www.jsqmd.com/news/869523/

相关文章:

  • 如何用MusicFree插件构建你的跨平台音乐生态:从零开始的全流程指南
  • XUnity自动翻译器终极指南:轻松实现游戏多语言无障碍体验
  • 区块链+AI+边缘计算:构建可信、高效的糖尿病风险预测系统
  • RDP Wrapper技术架构深度解析:破解Windows远程桌面限制的完整方案
  • 从音乐囚徒到音乐主人:Unlock Music Electron 终极音乐解锁指南
  • Blender 3MF插件:打破3D打印数据孤岛的技术桥梁
  • 一文带你看懂多模态大模型的降维打击!
  • 用SigmaStudio Plus如何来开发ADAU1466(1)软硬件开发环境的搭建和正确性检测
  • 普通人能做的最新商机哪里找?集客大师告诉你!
  • 移民美国项目怎么选 专业服务助家庭规划 - 品牌排行榜
  • RK3588 下位机搜索不到问题排查
  • 嵌入式开发新范式:C与JavaScript混合编程架构与实践
  • 2026年6月PMP最后15天:放弃幻想,照抄这份极简计划
  • 2026年移民美国项目公司选择要点分析 - 品牌排行榜
  • 2026水果店加盟选哪家?从产品到服务的全方位对比分析 - 品牌排行榜
  • THINKPHP 8 + PHP 8.0 + 40+功能优化,多商户系统v4.0为“百亿GMV”铺路
  • 5步掌握Nexus Mods App:告别模组管理烦恼的开源神器
  • 测绘行业数据安全解决方案
  • 深入解析LiteOS-M内核队列:数据结构、算法与嵌入式通信优化
  • 京尚放大招!一口锅一个码,全程透明不忽悠
  • 代码段权限RWX
  • ARM CoreLink 系列 4.3 -- NI-700 Component and interface identifiers
  • AI经营报告项目——项目记录
  • 广东厨房收纳配件供应商推荐,图特股份等企业可提供定制服务
  • 跨平台线程池组件设计:从核心原理到C++实现详解
  • PyCharm无法引用本地扩展包问题的结解决方法
  • 踩坑记录:爬虫代理 403/超时问题的 5 层排查法
  • 微信小程序 宠物服务系统
  • asnumpy 昇腾版 NumPy:在 NPU 上跑你的科学计算代码
  • 外卖门店经营数据看板(Excel动态仪表板)