保姆级教程:手把手复现MAE(Masked Autoencoder)图像预训练(PyTorch版)
从零实现MAE图像预训练:PyTorch实战指南
1. 环境准备与数据预处理
在开始构建MAE模型之前,我们需要搭建合适的开发环境并准备数据集。以下是完整的配置流程:
基础环境要求:
- Python 3.8+
- PyTorch 1.10+
- CUDA 11.3(如使用GPU加速)
- torchvision 0.11+
# === Environment setup ===

# Create and activate a dedicated conda environment
conda create -n mae python=3.8
conda activate mae

# Install PyTorch with CUDA 11.3 wheels
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

# Install the remaining dependencies
pip install timm matplotlib numpy pandas

# For dataset handling we use ImageNet-1k as the example; swap in any other
# image dataset as needed
# (对于数据集处理,我们将使用ImageNet-1k作为示例。实际应用中可根据需求替换为其他图像数据集)
import torch
from torchvision import datasets, transforms

# ImageNet channel statistics, shared by both pipelines
_NORMALIZE = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])

# Training augmentation: random resized crop + horizontal flip
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _NORMALIZE,
])

# Deterministic resize / center-crop for validation
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _NORMALIZE,
])

# Datasets in the standard ImageNet directory layout
train_dataset = datasets.ImageFolder(root='path/to/imagenet/train',
                                     transform=train_transform)
val_dataset = datasets.ImageFolder(root='path/to/imagenet/val',
                                   transform=val_transform)

batch_size = 256

# Loaders: shuffle only during training; pinned memory speeds host->GPU copies
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=8, pin_memory=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False,
    num_workers=4, pin_memory=True,
)

# === 2. MAE core components (MAE核心组件实现) ===
2.1 Patch嵌入与位置编码
MAE首先将图像分割为规则的patch网格,这是Vision Transformer的标准处理方式:
import torch.nn as nn
import math


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    The split + projection is realised as a single strided convolution whose
    kernel size and stride both equal the patch size.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """(B, C, H, W) image batch -> (B, N, E) patch-token sequence."""
        tokens = self.proj(x)             # (B, E, H/P, W/P)
        tokens = tokens.flatten(2)        # (B, E, N)
        return tokens.transpose(1, 2)     # (B, N, E)


class PositionEmbedding(nn.Module):
    """Learnable absolute position embedding added to the patch tokens."""

    def __init__(self, n_patches, embed_dim, dropout=0.1):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
        self.dropout = nn.Dropout(p=dropout)
        # Truncated-normal init, as in the ViT/MAE reference implementations
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        return self.dropout(x + self.pos_embed)


# === 2.2 Random mask generation (随机掩码生成) ===
MAE的核心创新之一是采用高比例随机掩码策略:
def random_masking(x, mask_ratio=0.75):
    """Per-sample random masking via argsort of i.i.d. uniform noise.

    Args:
        x: (B, N, E) patch-token sequence.
        mask_ratio: fraction of tokens to hide.

    Returns:
        x_masked: (B, len_keep, E) visible tokens only.
        mask: (B, N) binary mask in the ORIGINAL token order
            (1 = kept/visible, 0 = masked out).
        ids_restore: (B, N) indices that undo the shuffle.
    """
    B, N, E = x.shape
    n_visible = int(N * (1 - mask_ratio))

    # Rank tokens by random noise; the first n_visible ranks survive
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    visible_ids = ids_shuffle[:, :n_visible]
    x_masked = torch.gather(
        x, dim=1, index=visible_ids.unsqueeze(-1).expand(-1, -1, E))

    # Build the binary mask in shuffled order, then map back to token order
    mask = torch.zeros([B, N], device=x.device)
    mask[:, :n_visible] = 1
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore


# === 2.3 Transformer encoder-decoder (Transformer编码器-解码器架构) ===
MAE采用非对称的编码器-解码器设计:
class TransformerBlock(nn.Module):
    """Standard pre-norm ViT block: LN -> MHSA -> residual, LN -> MLP -> residual.

    Added here because the encoder/decoder below reference it but the
    original article never defined it anywhere (it would raise NameError).
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x


class TransformerEncoder(nn.Module):
    """Stack of Transformer blocks applied to the visible tokens only."""

    def __init__(self, embed_dim, depth, num_heads, mlp_ratio=4.):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)


class TransformerDecoder(nn.Module):
    """Lightweight decoder that reconstructs pixel values for every patch.

    Visible-token latents are first projected to the (smaller) decoder width,
    then padded with a shared learnable mask token and un-shuffled back to the
    original patch order before the decoder blocks run.
    """

    def __init__(self, embed_dim, decoder_embed_dim, depth, num_heads, mlp_ratio=4.):
        super().__init__()
        # Projection from encoder width to decoder width
        self.proj = nn.Linear(embed_dim, decoder_embed_dim)
        # Shared learnable token standing in for every masked patch
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        nn.init.normal_(self.mask_token, std=0.02)
        self.blocks = nn.ModuleList([
            TransformerBlock(decoder_embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(decoder_embed_dim)
        # Predict raw RGB pixel values for one 16x16 patch per token
        self.head = nn.Linear(decoder_embed_dim, 3 * 16**2)

    def forward(self, x, ids_restore):
        # BUGFIX: project to decoder width BEFORE concatenating with the mask
        # tokens. The original concatenated embed_dim tokens with
        # decoder_embed_dim mask tokens, which fails whenever the two widths
        # differ (768 vs 512 with the defaults).
        x = self.proj(x)
        n_masked = ids_restore.shape[1] - x.shape[1]
        mask_tokens = self.mask_token.repeat(x.shape[0], n_masked, 1)
        x_ = torch.cat([x, mask_tokens], dim=1)
        # Undo the random shuffle so token i again corresponds to patch i
        x_ = torch.gather(
            x_, dim=1,
            index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.shape[2]))
        for blk in self.blocks:
            x_ = blk(x_)
        x_ = self.norm(x_)
        return self.head(x_)


# === 3. Full MAE model (完整MAE模型集成) ===
将上述组件组合成完整的MAE模型:
class MAE(nn.Module):
    """Masked Autoencoder: asymmetric ViT encoder/decoder for self-supervised
    pixel reconstruction (He et al., 2021).

    The encoder only sees the visible patches; the lightweight decoder
    reconstructs pixel values for the full patch grid.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, encoder_depth=12, num_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., mask_ratio=0.75):
        super().__init__()
        self.mask_ratio = mask_ratio

        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        n_patches = self.patch_embed.n_patches

        # Learnable absolute position embedding
        self.pos_embed = PositionEmbedding(n_patches, embed_dim)

        # Asymmetric encoder / decoder
        self.encoder = TransformerEncoder(embed_dim, encoder_depth, num_heads, mlp_ratio)
        self.decoder = TransformerDecoder(
            embed_dim, decoder_embed_dim, decoder_depth,
            decoder_num_heads, mlp_ratio
        )

        self.initialize_weights()

    def initialize_weights(self):
        """(Re-)initialize position embedding, mask token and linear layers.

        NOTE: pos_embed and mask_token are already initialized inside their
        own constructors; re-running the init here is redundant but harmless.
        """
        nn.init.trunc_normal_(self.pos_embed.pos_embed, std=0.02)
        nn.init.normal_(self.decoder.mask_token, std=0.02)
        nn.init.xavier_uniform_(self.decoder.proj.weight)
        nn.init.zeros_(self.decoder.proj.bias)
        nn.init.xavier_uniform_(self.decoder.head.weight)
        nn.init.zeros_(self.decoder.head.bias)

    def forward(self, x, mask_ratio=None):
        """Run one masked forward pass.

        Args:
            x: (B, C, H, W) image batch.
            mask_ratio: optional per-call override of the configured ratio.
                BUGFIX: the progressive-masking example later in the article
                calls ``model(images, current_mask_ratio)``, which the
                original signature did not accept; the default preserves the
                old behavior.

        Returns:
            pred: (B, N, patch_pixels) per-patch reconstruction.
            mask: (B, N) binary mask (1 = visible, 0 = masked).
        """
        if mask_ratio is None:
            mask_ratio = self.mask_ratio

        x = self.patch_embed(x)
        x = self.pos_embed(x)

        # Drop a random subset of tokens; only the rest reach the encoder
        x_masked, mask, ids_restore = random_masking(x, mask_ratio)
        latent = self.encoder(x_masked)

        # Decoder restores the full sequence and predicts pixels
        pred = self.decoder(latent, ids_restore)
        return pred, mask


# === 4. Training strategy and loss (训练策略与损失函数) ===
MAE的训练需要特殊的损失计算和优化策略:
def train_mae(model, train_loader, optimizer, epoch, device): model.train() total_loss = 0 for batch_idx, (images, _) in enumerate(train_loader): images = images.to(device) # 前向传播 pred, mask = model(images) # 计算损失(仅在掩码patch上) target = model.patch_embed(images) target = target.detach() # 归一化目标(可选) mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) target = (target - mean) / (var + 1e-6)**0.5 # 仅计算掩码patch的损失 loss = (pred - target) ** 2 loss = loss.mean(dim=-1) # [N, L], 每个patch的损失 loss = (loss * (1 - mask)).sum() / (1 - mask).sum() # 平均仅对掩码patch # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() if batch_idx % 100 == 0: print(f'Train Epoch: {epoch} [{batch_idx * len(images)}/{len(train_loader.dataset)}]' f'\tLoss: {loss.item():.4f}') avg_loss = total_loss / len(train_loader) print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}') return avg_loss关键训练参数配置:
| 参数 | 推荐值 | 说明 |
|---|---|---|
| 基础学习率 | 1.5e-4 | 使用线性缩放规则(lr = base_lr * batch_size / 256) |
| 优化器 | AdamW | 权重衰减0.05 |
| 训练周期 | 400-1600 | 更长训练通常带来更好效果 |
| 批量大小 | 256-2048 | 根据GPU内存调整 |
| 学习率调度 | 余弦衰减 | 带warmup(40周期) |
| 权重衰减 | 0.05 | 防止过拟合 |
# Initialize the model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MAE(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    encoder_depth=12,
    num_heads=12,
    decoder_embed_dim=512,
    decoder_depth=8,
    decoder_num_heads=16,
    mlp_ratio=4,
    mask_ratio=0.75,
).to(device)

# AdamW with the MAE paper's betas; base lr follows the linear scaling rule
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1.5e-4 * 256 / 256,  # base lr scaled by batch_size / 256
    betas=(0.9, 0.95),
    weight_decay=0.05,
)

# Cosine decay of the learning rate over the full schedule
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=400, eta_min=1e-6
)

# Pre-training loop
for epoch in range(1, 401):
    train_mae(model, train_loader, optimizer, epoch, device)
    lr_scheduler.step()

# === 5. Model evaluation and applications (模型评估与应用) ===
5.1 可视化重建效果
import matplotlib.pyplot as plt
import numpy as np  # BUGFIX: np is used below but was never imported


def visualize_reconstruction(model, val_loader, device, num_examples=5):
    """Plot masked inputs next to the model's reconstructions for one batch."""
    model.eval()
    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)
            pred, mask = model(images)

            B = pred.shape[0]
            patch_size = model.patch_embed.patch_size
            grid = 224 // patch_size  # 14 for the default config

            # BUGFIX: un-patchify keeping the grid and patch axes separate;
            # the original reshape/permute interleaved them, scrambling the
            # pixels. pred: (B, N, p*p*3) with per-patch layout (p, p, C).
            recon = pred.reshape(B, grid, grid, patch_size, patch_size, 3)
            recon = recon.permute(0, 5, 1, 3, 2, 4).reshape(B, 3, 224, 224)

            # Undo the ImageNet normalization for display
            mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
            images = images * std + mean
            recon = recon * std + mean

            fig, axes = plt.subplots(num_examples, 2, figsize=(10, num_examples * 5))
            for i in range(num_examples):
                # Input with masked patches blacked out (mask: 1 = visible)
                masked_img = images[i].cpu().numpy().transpose(1, 2, 0)
                mask_ = mask[i].cpu().numpy().reshape(grid, grid)
                mask_ = torch.nn.functional.interpolate(
                    torch.from_numpy(mask_).float()[None, None, ...],
                    scale_factor=patch_size, mode='nearest'
                )[0, 0].numpy()
                masked_img = masked_img * mask_[..., None]

                recon_img = recon[i].cpu().numpy().transpose(1, 2, 0)
                recon_img = np.clip(recon_img, 0, 1)

                axes[i, 0].imshow(masked_img)
                axes[i, 0].set_title('Masked Input')
                axes[i, 0].axis('off')
                axes[i, 1].imshow(recon_img)
                axes[i, 1].set_title('Reconstruction')
                axes[i, 1].axis('off')

            plt.tight_layout()
            plt.show()
            break


visualize_reconstruction(model, val_loader, device)

# === 5.2 Transfer learning to downstream tasks (下游任务迁移学习) ===
预训练完成后,我们可以将编码器用于各种下游任务:
class FineTuneModel(nn.Module):
    """Linear-probe / fine-tune wrapper around a pre-trained MAE encoder."""

    def __init__(self, encoder, num_classes):
        super().__init__()
        self.encoder = encoder
        # BUGFIX: TransformerEncoder defines no `embed_dim` attribute, so the
        # original `encoder.encoder.embed_dim` raised AttributeError; read the
        # width from the position-embedding parameter (1, N, embed_dim).
        embed_dim = encoder.pos_embed.pos_embed.shape[-1]
        self.head = nn.Linear(embed_dim, num_classes)

        # Freeze the encoder (optional; remove this loop for full fine-tuning)
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(self, x):
        # Embed patches and add positions — no masking at fine-tune time
        x = self.encoder.patch_embed(x)
        x = self.encoder.pos_embed(x)
        features = self.encoder.encoder(x)
        # Global average pooling over the token dimension
        features = features.mean(dim=1)
        return self.head(features)


# Build the fine-tuning model on top of the pre-trained MAE
finetune_model = FineTuneModel(model, num_classes=1000).to(device)


def train_finetune(model, train_loader, optimizer, criterion, epoch, device):
    """Run one supervised fine-tuning epoch; returns (avg_loss, accuracy%)."""
    model.train()
    total_loss = 0
    correct = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(images)}/{len(train_loader.dataset)}]'
                  f'\tLoss: {loss.item():.4f}\tAcc: {100. * correct / ((batch_idx + 1) * len(images)):.2f}%')

    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}\tAccuracy: {accuracy:.2f}%')
    return avg_loss, accuracy


# === 6. Advanced tricks and optimization (高级技巧与优化) ===
6.1 渐进式掩码比例
训练初期使用较低掩码比例,逐步增加到目标比例:
def get_current_mask_ratio(epoch, max_epochs, final_ratio=0.75):
    """Linearly ramp the mask ratio from 0.5 up to ``final_ratio``.

    Args:
        epoch: current epoch; the ramp reaches ``final_ratio`` exactly at
            ``epoch == max_epochs`` and is clamped beyond it.
        max_epochs: length of the ramp in epochs.
        final_ratio: target (and maximum) mask ratio.
    """
    start_ratio = 0.5
    return min(final_ratio,
               start_ratio + (final_ratio - start_ratio) * (epoch / max_epochs))


# Inside the training loop (illustrative — `epoch`, `model` and `images` exist
# there; BUGFIX: left commented out because at module level these names are
# undefined and the snippet would raise NameError if run as-is):
# current_mask_ratio = get_current_mask_ratio(epoch, 400)
# pred, mask = model(images, current_mask_ratio)

# === 6.2 Learning-rate warmup (学习率warmup) ===
def adjust_learning_rate(optimizer, epoch, max_epochs, base_lr):
    """Linear warmup for 40 epochs, then half-cosine decay to zero."""
    warmup_epochs = 40
    if epoch < warmup_epochs:
        # Ramp linearly from 0 up to base_lr
        lr = base_lr * epoch / warmup_epochs
    else:
        # Cosine anneal from base_lr down to 0 over the remaining epochs
        progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
        lr = base_lr * 0.5 * (1. + math.cos(math.pi * progress))
    for group in optimizer.param_groups:
        group['lr'] = lr
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()


def train_mae_amp(model, train_loader, optimizer, epoch, device):
    """Run one MAE pre-training epoch under automatic mixed precision."""
    model.train()
    total_loss = 0

    for batch_idx, (images, _) in enumerate(train_loader):
        images = images.to(device)
        optimizer.zero_grad()

        with autocast():
            pred, mask = model(images)
            # BUGFIX: reconstruct raw pixel patches, not the learned
            # patch-embedding output (the sizes only matched because
            # embed_dim = 768 = 3 * 16**2 by coincidence).
            p = model.patch_embed.patch_size
            B, C, H, W = images.shape
            target = images.reshape(B, C, H // p, p, W // p, p)
            target = target.permute(0, 2, 4, 3, 5, 1)
            target = target.reshape(B, (H // p) * (W // p), p * p * C).detach()
            loss = ((pred - target) ** 2).mean(dim=-1)
            loss = (loss * (1 - mask)).sum() / (1 - mask).sum()  # masked only

        # Scale the loss so fp16 gradients do not underflow
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(images)}/{len(train_loader.dataset)}]'
                  f'\tLoss: {loss.item():.4f}')

    avg_loss = total_loss / len(train_loader)
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')
    return avg_loss


# === 7. Practical considerations (实际应用中的注意事项) ===
硬件要求:
- ViT-Base (ViT-B) 需要约16GB GPU内存(批量大小256)
- ViT-Large (ViT-L) 需要约32GB GPU内存
- 考虑使用梯度累积技术减少内存需求
训练时间优化:
- 使用混合精度训练(AMP)加速
- 采用分布式数据并行(DDP)进行多GPU训练
- 预加载数据到内存减少I/O等待
模型保存与加载:
# Save a full training checkpoint (weights + optimizer + progress)
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    },
    'mae_checkpoint.pth',
)

# Restore it later
checkpoint = torch.load('mae_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

# 调试技巧: (Debugging tips:)
- 监控重建损失和可视化结果
- 检查梯度范数防止梯度爆炸/消失
- 使用更小的模型和数据集进行快速原型验证
8. 扩展应用与变体
8.1 多模态MAE
将MAE扩展到多模态数据,如图像-文本对:
class MultiModalMAE(nn.Module):
    """Sketch of a joint image/text MAE.

    NOTE: relies on a ``TextMAE`` class that is not defined in this article.
    """

    def __init__(self, image_config, text_config):
        super().__init__()
        self.image_mae = MAE(**image_config)
        self.text_mae = TextMAE(**text_config)
        # Fuse pooled image and text latents back to the image width
        self.cross_modal_head = nn.Linear(
            image_config['embed_dim'] + text_config['embed_dim'],
            image_config['embed_dim']  # or another fusion width
        )

    def forward(self, images, text):
        image_latent = self.image_mae.encoder(images)
        text_latent = self.text_mae.encoder(text)
        # Mean-pool each modality over its sequence, then fuse
        pooled = torch.cat(
            [image_latent.mean(dim=1), text_latent.mean(dim=1)], dim=1)
        return self.cross_modal_head(pooled)


# === 8.2 Hierarchical MAE (分层MAE) ===
class HierarchicalMAE(nn.Module):
    """Sketch of a two-scale MAE: one stage per input resolution.

    Fixes over the original sketch:
      * each stage's ``img_size`` now matches the resolution it is fed
        (the original fed the 112px input to the 224px stage and vice versa,
        which breaks the position-embedding length);
      * ``MAE.forward`` returns ``(pred, mask)`` — the tuple is unpacked
        before the features are concatenated;
      * ``F`` was never imported; ``torch.nn.functional`` is used explicitly.
    """

    def __init__(self):
        super().__init__()
        self.stage1 = MAE(img_size=112, patch_size=8)    # low-resolution branch
        self.stage2 = MAE(img_size=224, patch_size=16)   # high-resolution branch
        self.merge = nn.Linear(768 * 2, 768)

    def forward(self, x):
        # Stage 1: low-resolution pass
        x_low = torch.nn.functional.interpolate(x, size=112, mode='bilinear')
        h1, _ = self.stage1(x_low)

        # Stage 2: full-resolution pass
        h2, _ = self.stage2(x)

        # Both branches yield a 196-token prediction; fuse them per token
        return self.merge(torch.cat([h1, h2], dim=-1))


# === 9. Performance optimization tips (性能优化技巧) ===
内存优化:
- 使用梯度检查点技术
- 采用更高效的自注意力实现(如FlashAttention)
- 减少不必要的中间变量保存
计算优化:
- 使用torch.compile()(PyTorch 2.0+)
- 优化矩阵乘法顺序
- 利用CUDA核心的Tensor Core
批处理策略:
- 动态批处理(根据图像大小)
- 使用NVIDIA DALI加速数据加载
- 预计算静态图(JIT编译)
# Compile the model with PyTorch 2.0's graph compiler for extra throughput
model = torch.compile(model, mode='max-autotune')

# === 10. Troubleshooting common problems (常见问题解决方案) ===
训练不稳定:
- 添加梯度裁剪(torch.nn.utils.clip_grad_norm_)
- 使用更小的学习率或更长的warmup
- 尝试不同的初始化策略
过拟合:
- 增加掩码比例(最高可达90%)
- 使用更强的数据增强
- 添加DropPath(随机深度)正则化
收敛慢:
- 检查学习率调度
- 验证数据预处理是否正确
- 尝试更大的模型容量
GPU内存不足:
- 减少批量大小
- 使用梯度累积
- 尝试模型并行或更高效的注意力实现
# Gradient accumulation: simulate a batch that is accum_steps times larger
accum_steps = 4
optimizer.zero_grad()

for i, (images, _) in enumerate(train_loader):
    images = images.to(device)
    with autocast():
        pred, mask = model(images)
        # BUGFIX: target is the raw pixel patches, not the patch_embed output
        p = model.patch_embed.patch_size
        B, C, H, W = images.shape
        target = images.reshape(B, C, H // p, p, W // p, p)
        target = target.permute(0, 2, 4, 3, 5, 1)
        target = target.reshape(B, (H // p) * (W // p), p * p * C).detach()
        loss = ((pred - target) ** 2).mean(dim=-1)
        loss = (loss * (1 - mask)).sum() / (1 - mask).sum()
        # Scale down so the accumulated gradient matches one big batch
        loss = loss / accum_steps

    scaler.scale(loss).backward()

    if (i + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

# BUGFIX: flush the gradients left over when the number of batches is not a
# multiple of accum_steps (the original silently dropped them)
if len(train_loader) % accum_steps != 0:
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()