[DL_Net从入门到入土] 变分自编码器 VAE
[DL_Net从入门到入土] 变分自编码器 VAE
📢个人导航
知乎:https://www.zhihu.com/people/byzh_rc
CSDN:https://blog.csdn.net/qq_54636039
注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码
参考文章:各方资料
文章目录
- [DL_Net从入门到入土] 变分自编码器 VAE
- 📢个人导航
- 📖参考资料
- 🌱背景
- ⚙️架构(公式)
- 1. Encoder
- 2. 隐变量z
- 3. Decoder
- 4. 重参数化技巧
- 5. 损失函数
- 👍优点/创新点
- 1. 隐空间更加连续
- 2. 可以生成新样本
- 3. 引入概率建模思想
- 👎缺点
- 1. 生成图像可能比较模糊
- 2. KL 散度可能导致重建质量下降
- 3. 可能出现 posterior collapse
- 💻代码实现
📖参考资料
Auto-Encoding Variational Bayes.
🌱背景
VAE: Variational AutoEncoder: 变分自编码器
传统 AE 有一个问题: 没有规定隐变量z zz应该长什么样, 不要求隐空间具有良好的结构
假设 AE 把三张图片编码成下面几个点:
- 数字 1 → z = [1.0, 1.0]
- 数字 7 → z = [5.0, 5.0]
- 数字 9 → z = [20.0, 20.0]
这些点本身可以被 AE 的 Decoder 还原出来
但是问题在于:如果我们随便采样一个点,比如z = [ 10.0 , 10.0 ] z=[10.0,10.0]z=[10.0,10.0],Decoder 不能生成一个合理的数字
因为 AE 训练时只见过 Encoder 生成出来的那些z zz, 故而没有要求整个隐空间都连续、平滑、有意义
VAE 想让隐空间变得更加规整
普通 AE 是把输入编码成一个确定的点:x → z x→zx→z
而VAE 会把输入编码成一个概率分布:x → μ , σ x→μ,σx→μ,σ
约束这些分布都接近标准正态分布N ( 0 , I ) \mathcal{N}(0, I)N(0,I), 那么隐空间就会变得更加连续、规整
对于一个输入x xx,Encoder 不再输出一个固定的z zz,而是输出一个分布的参数:
- μ \muμ:均值
- σ \sigmaσ:标准差
然后再从这个分布中采样一个z zz:
z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2)z∼N(μ,σ2)
最后用 Decoder 把z zz还原成x ^ \hat{x}x^:
x ^ = g ϕ ( z ) \hat{x} = g_\phi(z)x^=gϕ(z)
⚙️架构(公式)
输入 x ↓ Encoder ↓ μ, logσ² !!! ↓ 重参数化采样 !!! ↓ 隐变量 z !!! ↓ Decoder ↓ 重建结果 x'1. Encoder
AE 的 Encoder 是:
z = f θ ( x ) z = f_\theta(x)z=fθ(x)
VAE 的 Encoder是:
μ , log σ 2 = f θ ( x ) \mu, \log \sigma^2 = f_\theta(x)μ,logσ2=fθ(x)
2. 隐变量z
VAE 假设隐变量z zz服从一个高斯分布:
z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2)z∼N(μ,σ2)
-> 对于同一个输入x xx, VAE 不认为它只能对应一个固定的z zz, 而是认为它可以对应一片附近区域
3. Decoder
Decoder 的作用和 AE 类似:
x ^ = g ϕ ( z ) \hat{x} = g_\phi(z)x^=gϕ(z)
也就是把隐变量z zz解码成重建结果x ^ \hat{x}x^
4. 重参数化技巧
采样操作本身是随机的,普通采样无法直接反向传播
如z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2)z∼N(μ,σ2), 这个采样过程会导致梯度不好传回 Encoder
原本是:
z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2)z∼N(μ,σ2)
改写成:
z = μ + σ ⊙ ϵ z = \mu + \sigma \odot \epsilonz=μ+σ⊙ϵ
- ϵ \epsilonϵ:从标准正态分布中随机采样的噪声ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I)ϵ∼N(0,I)
- μ \muμ:Encoder 输出的均值
- σ \sigmaσ:Encoder 输出的标准差
- ⊙ \odot⊙:逐元素相乘
-> 随机性被转移到了ϵ \epsilonϵ上, 那么μ \muμ和σ \sigmaσ就可以参与梯度计算了
5. 损失函数
由两部分组成:
L V A E = L r e c o n + L K L \mathcal{L}_{VAE} = \mathcal{L}_{recon} + \mathcal{L}_{KL}LVAE=Lrecon+LKL
| 目标 | 作用 | 损失 |
|---|---|---|
| 重建好 | 让x ^ \hat{x}x^尽可能接近x xx | L r e c o n \mathcal{L}_{recon}Lrecon |
| 隐空间规整 | 让z zz的分布接近标准正态分布 | L K L \mathcal{L}_{KL}LKL |
L r e c o n = M S E ( x , x ^ ) ∣ B C E ( x , x ^ ) \mathcal{L}_{recon}=MSE(x,\hat{x}) | BCE(x,\hat{x})Lrecon=MSE(x,x^)∣BCE(x,x^)
像 MNIST 这类像素范围在[ 0 , 1 ] [0,1][0,1]的图片,常用 BCE
L K L = − 1 2 ∑ ( 1 + log σ 2 − μ 2 − σ 2 ) \mathcal{L}_{KL}=-\frac{1}{2}\sum\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)LKL=−21∑(1+logσ2−μ2−σ2)
👍优点/创新点
1. 隐空间更加连续
这样在两个样本之间插值时,VAE 更容易得到合理的过渡结果
2. 可以生成新样本
训练完成后,我们可以直接随机采样一个z zz:
z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I)z∼N(0,I)
然后送入 Decoder,x ^ = g ϕ ( z ) \hat{x} = g_\phi(z)x^=gϕ(z), 从而生成新的数据
这就是 VAE 能够作为生成模型的关键原因
3. 引入概率建模思想
AE 是确定性模型:
x → z → x ^ x \rightarrow z \rightarrow \hat{x}x→z→x^
VAE 是概率生成模型:
x → q θ ( z ∣ x ) → z → p ϕ ( x ∣ z ) x \rightarrow q_\theta(z|x) \rightarrow z \rightarrow p_\phi(x|z)x→qθ(z∣x)→z→pϕ(x∣z)
-> VAE 不只是学习一个编码结果,而是学习一个潜在变量的概率分布
👎缺点
1. 生成图像可能比较模糊
VAE 经常使用 MSE 或 BCE 作为重建损失, 这些损失函数倾向于让模型生成“平均化”的结果
对于图像来说,就可能导致生成图片不够锐利,看起来比较模糊
这也是为什么很多情况下,GAN 生成的图片会比 VAE 更清晰
如果只看生成图片的质量,传统 VAE 往往不如 GAN 或 Diffusion Model
2. KL 散度可能导致重建质量下降
VAE 不仅要重建输入,还要让隐变量分布接近标准正态分布
-. 这两个目标有时会互相冲突
如果 KL 约束太强,模型为了让z zz更接近标准正态分布,可能会牺牲一部分重建效果
3. 可能出现 posterior collapse
posterior collapse: Decoder 太强了,导致它几乎不依赖隐变量z zz
模型虽然形式上有 Encoder 和隐变量z zz,但 Decoder 实际上可能忽略了z zz
-> 此时隐变量没有学到有用信息
💻代码实现
importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassVAE(nn.Module):def__init__(self,input_dim=784,hidden_dim=128,latent_dim=32):super(VAE,self).__init__()# Encoder: x -> hiddenself.encoder=nn.Sequential(nn.Linear(input_dim,hidden_dim),nn.ReLU())# 分别输出 mu 和 logvarself.fc_mu=nn.Linear(hidden_dim,latent_dim)self.fc_logvar=nn.Linear(hidden_dim,latent_dim)# Decoder: z -> x_hatself.decoder=nn.Sequential(nn.Linear(latent_dim,hidden_dim),nn.ReLU(),nn.Linear(hidden_dim,input_dim),nn.Sigmoid())def_encode(self,x):h=self.encoder(x)mu=self.fc_mu(h)logvar=self.fc_logvar(h)returnmu,logvardef_reparameterize(self,mu,logvar):# logvar = log(sigma^2)# std = sigmastd=torch.exp(0.5*logvar)# eps ~ N(0, I)eps=torch.randn_like(std)# z = mu + sigma * epsz=mu+std*epsreturnzdef_decode(self,z):x_hat=self.decoder(z)returnx_hatdefforward(self,x):mu,logvar=self._encode(x)z=self._reparameterize(mu,logvar)x_hat=self._decode(z)returnx_hat,mu,logvar