基于Transformer的自回归图像生成模型实现
1. 项目概述与背景
在计算机视觉领域,图像生成一直是一个极具挑战性的研究方向。传统的生成对抗网络(GAN)和变分自编码器(VAE)虽然取得了不错的效果,但都存在训练不稳定或生成质量有限的问题。近年来,基于Transformer的自回归模型在图像生成领域展现出强大的潜力。
本项目实现了一个基于Transformer的自回归图像生成模型,其核心思想是将图像分割为小块(patch),通过BSQ二值量化模块将每个patch编码为离散的token序列,然后使用Transformer模型对这些token进行自回归预测和生成。这种方法结合了离散表示的优势和Transformer强大的序列建模能力。
提示:自回归生成的核心特点是每个token的预测都依赖于之前生成的所有token,这与人类书写或绘画的过程非常相似。
2. 数据预处理流程
2.1 图像token化处理
数据预处理的第一步是将原始图像转换为token序列。这个过程依赖于预训练好的BSQPatchAutoEncoder模型:
python -m homework.tokenize checkpoints/YOUR_BSQPatchAutoEncoder.pth data/tokenized_train.pth data/train/*.jpg python -m homework.tokenize checkpoints/YOUR_BSQPatchAutoEncoder.pth data/tokenized_valid.pth data/valid/*.jpg这段代码会遍历指定目录下的所有图像文件,使用BSQPatchAutoEncoder将它们编码为token序列,并保存为.pth文件。关键参数说明:
patch_size=5:图像被分割为5×5的小块codebook_bits=10:每个patch被编码为10位二进制码,对应1024种可能的token
文件大小验证:
du -hs data/tokenized_train.pth对于典型的配置,生成的token文件大小约为76MB,具体取决于图像数量和分辨率。
2.2 数据格式解析
生成的token数据集具有以下结构:
- 每个样本是一个3维张量 (1, h, w)
- h和w取决于原始图像尺寸和patch大小
- 每个元素是0到1023之间的整数,代表对应patch的token ID
3. 模型架构设计
3.1 核心组件
模型的核心是AutoregressiveModel类,它继承自torch.nn.Module并实现了Autoregressive抽象基类:
class AutoregressiveModel(torch.nn.Module, Autoregressive): def __init__(self, d_latent: int = 128, n_tokens: int = 2**10): super().__init__() self.d_latent = d_latent # 潜在空间维度 self.n_tokens = n_tokens # token词汇表大小 self.L_max = 1024 # 最大序列长度 # 嵌入层 self.embedding = torch.nn.Embedding(num_embeddings=n_tokens, embedding_dim=d_latent) # 位置编码 self.pos_emb = torch.nn.Embedding(num_embeddings=self.L_max, embedding_dim=d_latent) # Transformer编码器 encoder_layer = torch.nn.TransformerEncoderLayer( d_model=d_latent, nhead=8, dim_feedforward=4*d_latent, activation="gelu", batch_first=True, norm_first=True, dropout=0.1 ) self.transformer = torch.nn.TransformerEncoder( encoder_layer=encoder_layer, num_layers=2, norm=torch.nn.LayerNorm(d_latent) ) # 输出层 self.fc_out = torch.nn.Linear(d_latent, n_tokens)3.2 因果掩码机制
自回归模型的关键是确保每个位置的预测只能依赖于之前的位置,这通过因果掩码实现:
def _generate_causal_mask(self, L: int, device: torch.device) -> torch.Tensor: """ 生成因果掩码:确保序列中第i个位置只能看到前i-1个位置 :param L: 序列长度 h*w :param device: 设备 :return: 掩码 (L, L),float型,上三角=-inf,下三角=0 """ mask = torch.nn.Transformer.generate_square_subsequent_mask(L, device=device) return mask这种掩码会阻止Transformer关注"未来"的token,保证生成过程的因果性。
4. 前向预测实现
4.1 前向传播流程
模型的前向传播过程包含以下步骤:
- 输入整形:将输入从(B, h, w)展平为(B, L),其中L=h*w
- token嵌入:通过Embedding层将整数token转换为连续向量
- 位置编码:为每个位置添加位置信息
- 序列右移:将整个序列向右移动一位,实现自回归特性
- Transformer编码:使用带因果掩码的Transformer处理序列
- 输出预测:通过线性层预测下一个token的概率分布
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: if x.dim() == 4: x = x.squeeze(1) B, h, w = x.shape L = h * w # 展平成序列 x_flat = x.reshape(B, L) # 嵌入 + 位置编码 token_emb = self.embedding(x_flat) pos_idx = torch.arange(L, device=x.device) pos_emb = self.pos_emb(pos_idx) x_emb = token_emb + pos_emb # 自回归右移(关键) x_emb = F.pad(x_emb, (0,0,1,0))[:, :-1] # 因果掩码 mask = self._generate_causal_mask(L, x.device) trans_out = self.transformer(x_emb, mask=mask) # 输出 logits = self.fc_out(trans_out) logits_2d = logits.reshape(B, h, w, self.n_tokens) return logits_2d, {}4.2 训练细节
模型训练使用标准的交叉熵损失函数:
criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) for epoch in range(num_epochs): for batch in dataloader: tokens = batch.to(device) logits, _ = model(tokens) # 计算损失 loss = criterion(logits.view(-1, n_tokens), tokens.view(-1)) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step()训练过程中需要注意:
- 学习率不宜过大,建议从1e-4开始
- 可以使用学习率调度器动态调整
- 监控训练和验证损失,防止过拟合
5. 自回归生成实现
5.1 生成算法
自回归生成是从空序列开始,逐步预测每个位置的token:
@torch.no_grad() def generate(self, B: int = 1, h: int = 30, w: int = 20, device=None) -> torch.Tensor: device = device or next(self.parameters()).device L = h * w tokens = torch.zeros((B, L), dtype=torch.long, device=device) for i in range(L): # 获取当前序列的logits logits, _ = self(tokens.reshape(B, h, w)) logits = logits.reshape(B, L, -1) # 只取当前位置的预测 curr_logits = logits[:, i, :] # 采样下一个token probs = F.softmax(curr_logits, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1) # 更新序列 if i < L - 1: tokens[:, i+1] = next_tokens.squeeze() return tokens.reshape(B, h, w)5.2 生成策略
在实际应用中,可以采用不同的生成策略:
贪心搜索:直接选择概率最大的token
next_tokens = torch.argmax(probs, dim=-1, keepdim=True)温度采样:通过温度参数控制生成的多样性
temperature = 0.7 scaled_logits = curr_logits / temperature probs = F.softmax(scaled_logits, dim=-1)Top-k采样:只从概率最高的k个token中采样
top_k = 40 values, indices = torch.topk(probs, top_k) probs = torch.zeros_like(probs).scatter_(-1, indices, values)
注意:温度参数越小,生成结果越确定;温度参数越大,生成结果越多样但可能不连贯。
6. 模型评估与结果
6.1 训练过程监控
训练过程中需要监控以下指标:
- 训练损失
- 验证损失
- 生成样本质量
典型的训练曲线如下图所示:
6.2 评估指标
除了常规的损失函数,还可以使用以下指标评估模型性能:
- 生成多样性:计算生成样本的token分布熵
- 重建质量:通过BSQ解码器将生成的token还原为图像,计算与原图的PSNR/SSIM
- 人类评估:人工评估生成图像的视觉质量
6.3 评分结果
项目评分系统给出的最终评估结果:
7. 实际应用与扩展
7.1 图像补全
该模型可用于图像补全任务:
- 给定部分图像token
- 使用自回归模型预测缺失部分
- 通过BSQ解码器还原完整图像
7.2 风格迁移
通过条件化生成,可以实现风格迁移:
- 在模型输入中添加风格编码
- 训练时使用风格分类器提供额外监督
- 生成时指定目标风格
7.3 模型优化方向
- 更大规模的训练:使用更多数据和更大模型
- 分层生成:先生成低分辨率图像,再逐步细化
- 混合架构:结合CNN和Transformer的优势
8. 常见问题与解决方案
8.1 训练不稳定
问题现象:损失值波动大,生成质量不一致
解决方案:
- 降低学习率
- 增加batch size
- 使用梯度裁剪
- 尝试不同的优化器(如AdamW)
8.2 生成重复模式
问题现象:生成图像出现重复的局部模式
解决方案:
- 增加温度参数
- 使用Top-k或Top-p采样
- 在训练数据中添加更多多样性
8.3 长序列生成质量差
问题现象:生成大尺寸图像时质量下降
解决方案:
- 使用相对位置编码
- 实现分块生成策略
- 增加Transformer层数
9. 工程实践建议
内存优化:对于大图像,使用梯度检查点减少内存占用
torch.utils.checkpoint.checkpoint(self.transformer, x_emb, mask)并行生成:利用GPU并行处理多个生成任务
@torch.no_grad() def batch_generate(self, B: int, h: int, w: int): # 批量生成实现 pass量化部署:使用TorchScript量化模型,提升推理速度
quantized_model = torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtype=torch.qint8 )
在实际部署中,我发现将模型转换为ONNX格式可以显著提升推理速度,特别是在边缘设备上。具体做法是:
dummy_input = torch.zeros(1, h, w, dtype=torch.long) torch.onnx.export( model, dummy_input, "model.onnx", input_names=["input"], output_names=["output"], dynamic_axes={ "input": {0: "batch_size"}, "output": {0: "batch_size"} } )对于需要生成高分辨率图像的应用,建议采用分块生成策略:先将图像分成若干块分别生成,然后使用特殊的边界token确保块与块之间的连续性。这种方法可以突破Transformer序列长度的限制,同时保持生成质量。
