当前位置: 首页 > news >正文

VAE In JAX【个人记录向】

和上一篇 SAC In JAX 一样,我们用 JAX 实现 VAE,配置一样,只需要安装符合版本的 torchvision 即可,实现中提供了 tensorboard 记录以及最后的可视化展示,测试集即为最经典的 MNIST,代码如下:

import jax
import flax
import math
import optax
import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from datetime import datetime
from torchvision import datasets
from flax.training import train_state
from matplotlib import pyplot as plt
from flax.training.train_state import TrainState
from stable_baselines3.common.logger import configure


def get_data():
    """Load MNIST, flatten images to 784-d vectors and min-max normalise.

    The test set is normalised with the *training* min/max statistics to
    avoid leaking test information. Returns
    (train_x, train_y, test_x, test_y, N, M) where N/M are set sizes.
    """
    train_dataset = datasets.MNIST(root='./data', train=True, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False)
    img_train = train_dataset.data.numpy().reshape(-1, 784)
    train_x_min = np.min(img_train, axis=0)
    train_x_max = np.max(img_train, axis=0)
    # Per-pixel min-max scaling; 1e-7 guards against constant pixels.
    train_x = (img_train - train_x_min) / (train_x_max - train_x_min + 1e-7)
    train_y = train_dataset.targets.numpy()
    img_test = test_dataset.data.numpy().reshape(-1, 784)
    test_x = (img_test - train_x_min) / (train_x_max - train_x_min + 1e-7)
    test_y = test_dataset.targets.numpy()
    N = train_x.shape[0]
    M = test_x.shape[0]
    return (jnp.asarray(train_x), jnp.asarray(train_y),
            jnp.asarray(test_x), jnp.asarray(test_y), N, M)


class VAE_encoder(nn.Module):
    """Two-head MLP encoder: x -> (mu, log-variance) of the latent Gaussian."""
    hidden_dim: int
    latent_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        encode = nn.relu(x)
        mu = nn.Dense(self.latent_dim)(encode)
        log_sig = nn.Dense(self.latent_dim)(encode)
        return mu, log_sig


class VAE_decoder(nn.Module):
    """MLP decoder: latent -> sigmoid pixel intensities in [0, 1]."""
    output_dim: int
    hidden_dim: int

    @nn.compact
    def __call__(self, latent):
        x = nn.Dense(self.hidden_dim)(latent)
        x = nn.relu(x)
        x = nn.Dense(self.output_dim)(x)
        x = nn.sigmoid(x)
        return x


class VAE:
    """Variational auto-encoder trained with the reparameterisation trick.

    Encoder and decoder each get their own Adam optimiser/TrainState; the
    PRNG key is threaded through every stochastic call and stored back on
    ``self.key``.
    """

    def __init__(self, input_dim, encoder_lr, decoder_lr, epochs, batch_size, logger, key):
        self.input_dim = input_dim
        # BUG FIX: the original wrote `self.encoder_lr, self.encoder_lr = ...`,
        # assigning encoder_lr twice and never storing decoder_lr.
        self.encoder_lr, self.decoder_lr = encoder_lr, decoder_lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.logger = logger
        self.key = key
        self.hidden_dim = 400
        self.latent_dim = 48
        self.encoder = VAE_encoder(self.hidden_dim, self.latent_dim)
        self.decoder = VAE_decoder(self.input_dim, self.hidden_dim)
        self.key, encoder_key, decoder_key = jax.random.split(self.key, 3)
        encoder_params = self.encoder.init(
            encoder_key, jnp.ones((self.batch_size, self.input_dim)))['params']
        decoder_params = self.decoder.init(
            decoder_key, jnp.ones((self.batch_size, self.latent_dim)))['params']
        encoder_optx = optax.adam(encoder_lr)
        decoder_optx = optax.adam(decoder_lr)
        self.encoder_state = TrainState.create(
            apply_fn=self.encoder.apply, params=encoder_params, tx=encoder_optx)
        self.decoder_state = TrainState.create(
            apply_fn=self.decoder.apply, params=decoder_params, tx=decoder_optx)

    @staticmethod
    @jax.jit
    def forward(x, encoder_state, decoder_state, now_key):
        """Encode x, sample a latent via reparameterisation, decode.

        Returns (reconstruction, advanced PRNG key).
        """
        mu, log_std = encoder_state.apply_fn({"params": encoder_state.params}, x)
        now_key, eps_key = jax.random.split(now_key, 2)
        eps = jax.random.normal(eps_key, shape=mu.shape)
        # z = mu + eps * sigma, with log_std holding log(sigma^2).
        latent = mu + eps * jnp.exp(log_std * 0.5)
        x_ = decoder_state.apply_fn({"params": decoder_state.params}, latent)
        return x_, now_key

    @staticmethod
    @jax.jit
    def train_step(data, encoder_state, decoder_state, key):
        """One joint gradient step on reconstruction + KL loss.

        Returns the updated states, the two loss components, and the
        advanced PRNG key.
        """
        def loss_fn(encoder_param, decoder_param, encoder_state, decoder_state, now_key):
            mu, log_std = encoder_state.apply_fn({"params": encoder_param}, data)
            now_key, eps_key = jax.random.split(now_key, 2)
            eps = jax.random.normal(eps_key, shape=mu.shape)
            latent = mu + eps * jnp.exp(log_std * 0.5)
            x_ = decoder_state.apply_fn({"params": decoder_param}, latent)
            # Squared-error reconstruction term, summed over pixels and batch
            # (original used a redundant nested sum; single sum is identical).
            construction_loss = jnp.sum(jnp.square(x_ - data))
            # KL divergence KL(N(mu, sigma) || N(0, I)); kept under the name
            # "commitment" to preserve the existing logger keys.
            commitment_loss = -0.5 * jnp.sum(1 + log_std - mu ** 2 - jnp.exp(log_std))
            loss = construction_loss + commitment_loss
            return loss, (construction_loss, commitment_loss, now_key)

        (loss, (construction, commitment_loss, key)), grads = jax.value_and_grad(
            loss_fn, has_aux=True, argnums=(0, 1))(
            encoder_state.params, decoder_state.params, encoder_state, decoder_state, key)
        encoder_state = encoder_state.apply_gradients(grads=grads[0])
        decoder_state = decoder_state.apply_gradients(grads=grads[1])
        return encoder_state, decoder_state, construction, commitment_loss, key

    def train(self, train_x, N):
        """Run mini-batch training for self.epochs, logging per-epoch losses."""
        for epoch in range(self.epochs):
            # Fresh shuffle of the training set each epoch.
            self.key, permutation_key = jax.random.split(self.key, 2)
            shuffled_indices = jax.random.permutation(permutation_key, N)
            now_data = train_x[shuffled_indices, :]
            tot_construction_loss, tot_commitment_loss = 0, 0
            for i in range(0, N, self.batch_size):
                batch_x = now_data[i: i + self.batch_size]
                (self.encoder_state, self.decoder_state, construction_loss,
                 commitment_loss, self.key) = VAE.train_step(
                    batch_x, self.encoder_state, self.decoder_state, self.key)
                tot_construction_loss += construction_loss
                tot_commitment_loss += commitment_loss
            now = datetime.now()
            time_str = now.strftime("%Y-%m-%d %H:%M:%S")
            print(f"Epoch {epoch + 1}, Construction_loss: {tot_construction_loss / N:.4f}, Commitment_loss: {tot_commitment_loss / N:.4f}! Time: {time_str}.")
            self.logger.record("Construction_loss", float(tot_construction_loss) / N)
            self.logger.record("Commitment_loss", float(tot_commitment_loss) / N)
            self.logger.dump(step=epoch + 1)


def plot(test_x, test_y, VAE_model):
    """Reconstruct one fixed sample per digit and save an original-vs-generated grid."""
    original_images = []
    for digit in range(10):
        # Deterministically pick the 112th test sample of each digit class.
        original_image = [test_x[i] for i in range(len(test_x)) if test_y[i] == digit][111]
        original_images.append(jnp.asarray(original_image.reshape(784)))
    # Renamed from `input`, which shadowed the Python builtin.
    batch = jnp.stack(original_images, axis=0)
    batch = jnp.asarray(batch, dtype=jnp.float32)
    output, VAE_model.key = VAE_model.forward(
        batch, VAE_model.encoder_state, VAE_model.decoder_state, VAE_model.key)
    generated_images = list(np.array(jax.lax.stop_gradient(output)))
    fig, axes = plt.subplots(2, 10, figsize=(15, 3))
    for i in range(10):
        axes[0, i].imshow(original_images[i].reshape(28, 28), cmap='gray')
        axes[0, i].set_title(f"Original {i}")
        axes[0, i].axis('off')
        axes[1, i].imshow(generated_images[i].reshape(28, 28), cmap='gray')
        axes[1, i].set_title(f"Generated {i}")
        axes[1, i].axis('off')
    plt.tight_layout()
    plt.savefig('result.png')
    plt.show()


def main():
    """Wire up logging, data, model; train; then plot reconstructions."""
    start_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_path = f"logs/VAEtest_{start_time}/"
    logger = configure(log_path, ["tensorboard"])
    train_x, train_y, test_x, test_y, N, M = get_data()
    key = jax.random.PRNGKey(41)
    VAEmodel = VAE(input_dim=784, encoder_lr=0.001, decoder_lr=0.001,
                   epochs=30, batch_size=128, logger=logger, key=key)
    VAEmodel.train(train_x, N)
    plot(test_x, test_y, VAEmodel)


if __name__ == '__main__':
    main()

实验结果

重建误差稳定下降:

image

在测试集中为每个数字各抽取一个固定样例(并非随机,代码固定取每类的第 112 个样本),展示重建后的图片,其实效果不太符合预期,应该是要调参,但是我懒得调了:

image

http://www.jsqmd.com/news/114/

相关文章:

  • BLE蓝牙配网双模式实操:STA+SoftAP技术原理与避坑指南
  • 【小白也能懂】PyTorch 里的 0.5 到底是干啥的?——一次把 Normalize 讲透! - 教程
  • 第58天:RCE代码amp;命令执行amp;过滤绕过amp;异或无字符amp;无回显方案amp;黑白盒挖掘
  • 057-Web攻防-SSRFDemo源码Gopher项目等
  • 060-WEB攻防-PHP反序列化POP链构造魔术方法流程漏洞触发条件属性修改
  • 059-Web攻防-XXE安全DTD实体复现源码等
  • 061-WEB攻防-PHP反序列化原生类TIPSCVE绕过漏洞属性类型特征
  • 051-Web攻防-文件安全目录安全测试源码等
  • Dilworth定理及其在算法题中的应用
  • error: xxxxx does not have a commit checked out
  • 049-WEB攻防-文件上传存储安全OSS对象分站解析安全解码还原目录执行
  • 云原生周刊:MetalBear 融资、Chaos Mesh 漏洞、Dapr 1.16 与 AI 平台新趋势
  • AI一周资讯 250913-250919
  • 045-WEB攻防-PHP应用SQL二次注入堆叠执行DNS带外功能点黑白盒条件-cnblog
  • linux 命令语句
  • 用 Kotlin 实现英文数字验证码识别
  • UM2003A 一款 200 ~ 960MHz ASK/OOK +18dBm 发射功率的单发射
  • 达芬奇(DaVinci Reslove)字体文件 bugb标签
  • 语音芯片怎样挑选?语音芯片关键选型要点?
  • KingbaseES Schema权限及空间限额
  • UM2003A 一款 200 ~ 960MHz ASK/OOK +18dBm 发射功率的单发射芯片
  • HTTP库开发实战:核心库与httpplus扩展库示例解析
  • QMT交易系统向服务器同步订单丢失问题排查
  • 笔记1
  • 用 Python 和 Tesseract 实现英文数字验证码识别
  • 深入解析:上门按摩平台 “0 抽成 + 无底薪” 双模式拆解:如何让技师主动创收?
  • 实用指南:OSPF特殊区域、路由汇总及其他特性
  • 禅道以及bug
  • SUB-1G 无线收发芯片 DP10RF001 低功耗 (G) FSK/OOK 智能门锁,资产追踪、无线监控
  • 中电金信 :MCP在智能体应用中的挑战与对策