当前位置: 首页 > news >正文

保姆级教程:用PyTorch复现MAE(Masked Autoencoders)预训练ViT,附完整代码与避坑指南

从零实现MAE:PyTorch实战高比例掩码自监督预训练

在计算机视觉领域,自监督学习正逐渐成为获取强大视觉表征的主流范式。2022年ICCV最佳论文MAE(Masked Autoencoders)提出了一种简单而高效的预训练方法,通过随机掩码75%的图像块并重建原始像素,使ViT模型在ImageNet-1K上达到了87.8%的top-1准确率。本文将带您从工程角度完整实现MAE预训练流程,涵盖以下关键环节:

1. 环境配置与数据准备

首先需要搭建适合大规模训练的PyTorch环境。推荐使用Python 3.8+和PyTorch 1.12+版本,同时安装timm库以获取ViT实现:

conda create -n mae python=3.8 conda activate mae pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html pip install timm==0.6.12

对于数据集处理,MAE原论文使用ImageNet-1K,但为快速验证我们可以选择CIFAR-10或Tiny-ImageNet。以下代码展示了自定义数据加载器的关键步骤:

from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) class MAEDataset(torch.utils.data.Dataset): def __init__(self, original_dataset): self.dataset = original_dataset def __getitem__(self, index): img, _ = self.dataset[index] # 忽略原始标签 return train_transform(img)

2. 核心架构实现

2.1 Patch嵌入与位置编码

MAE首先将图像分割为不重叠的patch(典型尺寸16×16),然后线性投影为token:

import torch.nn as nn class PatchEmbed(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=1024): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.num_patches = (img_size // patch_size) ** 2 def forward(self, x): x = self.proj(x) # [B, C, H, W] -> [B, D, H/P, W/P] x = x.flatten(2).transpose(1, 2) # [B, D, N] -> [B, N, D] return x

位置编码采用可学习的1D向量,与ViT保持一致:

class PositionalEncoding(nn.Module): def __init__(self, num_patches, embed_dim): super().__init__() self.pos_embed = nn.Parameter( torch.zeros(1, num_patches, embed_dim)) def forward(self, x): return x + self.pos_embed

2.2 非对称编解码器设计

MAE的核心创新在于其非对称架构——编码器仅处理可见patch,而解码器重建全部patch:

class MAEEncoder(nn.Module): def __init__(self, embed_dim, depth, num_heads): super().__init__() self.blocks = nn.ModuleList([ TransformerBlock(embed_dim, num_heads) for _ in range(depth) ]) def forward(self, x, mask_ratio=0.75): # 随机生成mask (实现细节见2.3节) B, N, D = x.shape len_keep = int(N * (1 - mask_ratio)) # 仅保留未mask的token ids_keep = torch.argsort(noise, dim=1)[:, :len_keep] x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D)) # 通过Transformer块 for blk in self.blocks: x_masked = blk(x_masked) return x_masked, ids_keep

解码器需要处理完整的token序列(含mask token):

class MAEDecoder(nn.Module): def __init__(self, embed_dim, decoder_dim, num_patches): super().__init__() self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) self.decoder_pos = PositionalEncoding(num_patches, decoder_dim) self.decoder_blocks = nn.ModuleList([ TransformerBlock(decoder_dim, num_heads=4) for _ in range(4) ]) self.head = nn.Linear(decoder_dim, 3*16*16) # 重建16x16 RGB patch def forward(self, x, ids_restore): # 将mask token插入编码器输出 mask_tokens = self.mask_token.repeat( x.shape[0], ids_restore.shape[1] - x.shape[1], 1) x_ = torch.cat([x, mask_tokens], dim=1) # 恢复原始顺序 x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x.shape[2])) # 添加位置编码并通过解码器 x_ = self.decoder_pos(x_) for blk in self.decoder_blocks: x_ = blk(x_) return self.head(x_)

2.3 掩码生成与序列恢复

实现高比例随机掩码需要注意以下关键点:

def random_masking(x, mask_ratio): B, N, D = x.shape len_keep = int(N * (1 - mask_ratio)) noise = torch.rand(B, N, device=x.device) # 均匀分布噪声 ids_shuffle = torch.argsort(noise, dim=1) # 升序排列 ids_restore = torch.argsort(ids_shuffle, dim=1) # 恢复索引 # 生成二进制mask (0保留, 1丢弃) mask = torch.ones([B, N], device=x.device) mask[:, :len_keep] = 0 mask = torch.gather(mask, dim=1, index=ids_restore) return ids_shuffle, ids_restore, mask

注意:MAE的掩码策略与BERT不同,不需要特殊[mask]标记,而是直接移除被mask的patch。这使得编码器计算量减少约75%。

3. 训练流程与损失计算

3.1 像素重建目标

MAE使用MSE损失,但仅计算被mask区域的像素误差:

class MAE(nn.Module): def __init__(self): super().__init__() self.patch_embed = PatchEmbed() self.encoder = MAEEncoder(depth=12, embed_dim=1024, num_heads=16) self.decoder = MAEDecoder(embed_dim=1024, decoder_dim=512) def forward(self, imgs, mask_ratio=0.75): # 图像分块 patches = self.patch_embed(imgs) # [B, N, D] # 编码可见patch x_encoded, ids_restore = self.encoder(patches, mask_ratio) # 解码重建 x_recon = self.decoder(x_encoded, ids_restore) # 计算mask区域MSE target = self.patchify(imgs) loss = (x_recon - target) ** 2 loss = loss.mean(dim=-1) # 各patch的均方误差 mask = self.get_mask(ids_restore, mask_ratio) loss = (loss * mask).sum() / mask.sum() # 仅mask区域 return loss

3.2 关键训练技巧

实际训练时需要特别注意以下超参数设置:

参数推荐值作用
学习率1.5e-4使用AdamW优化器
批量大小4096需多GPU分布式训练
热身epoch40线性学习率预热
权重衰减0.05防止过拟合
掩码比例75%论文最优值

分布式训练启动脚本示例:

python -m torch.distributed.launch --nproc_per_node=8 \ --nnodes=4 --node_rank=$RANK \ train_mae.py --batch_size 512 --accum_iter 8

4. 下游任务迁移

4.1 分类任务微调

预训练完成后,只需保留编码器并添加分类头:

from timm.models.vision_transformer import VisionTransformer class MAEForClassification(nn.Module): def __init__(self, pretrained_encoder): super().__init__() self.encoder = pretrained_encoder self.head = nn.Linear(1024, num_classes) def forward(self, x): # 完整图像通过编码器 patches = self.patch_embed(x) x = self.encoder(patches, mask_ratio=0) # 无mask # 使用class token或平均池化 return self.head(x.mean(dim=1))

4.2 目标检测适配

对于检测任务如Mask R-CNN,可将MAE编码器作为backbone:

def build_mae_backbone(cfg): from detectron2.modeling import Backbone class MAEBackbone(Backbone): def __init__(self, pretrained_encoder): super().__init__() self.encoder = pretrained_encoder self._out_features = ["block4", "block8", "block12"] def forward(self, x): features = {} x = self.encoder.patch_embed(x) for i, blk in enumerate(self.encoder.blocks): x = blk(x) if f"block{i+1}" in self._out_features: features[f"block{i+1}"] = x.permute(0, 2, 1).unflatten(2, (14, 14)) return features

在实现过程中,最容易出现的维度错误通常发生在以下环节:

  • patch嵌入后的维度转换(需确保从[B,C,H,W]到[B,N,D]的正确变形)
  • mask token与编码输出的拼接(需严格对齐序列位置)
  • 损失计算时的mask应用(需确保只计算被mask区域的像素)

经过完整训练周期后,建议通过可视化重建结果验证模型性能。良好的重建效果表明编码器已学习到有意义的视觉表征,即使被mask区域占原始图像的75%。

http://www.jsqmd.com/news/849030/

相关文章:

  • Zotero引文格式终极自定义指南:从IEEE期刊简称到会议名缩写,一篇搞定所有细节
  • Git基本操作(四):删除文件
  • AdBlock 自定义规则
  • 3步掌握Navicat无限试用重置:Mac用户的完整专业指南
  • 化工行业节能改造数据监测系统方案
  • 《CVPR2025-DEIM创新改进项目实战:从原理到部署的深度学习优化全攻略》004、DEIM数学基础:注意力机制与特征重标定的统一框架
  • 企业信息化架构(业务架构、应用架构、数据架构、技术架构)方案:四横五纵框架 、元模型+视图 、业务、应用、数据、技术四大架构
  • ncmdump终极解密指南:3分钟解锁网易云加密音乐文件
  • VIGOR:跨越“一对一”检索的理想假设,面向真实场景的跨视角地理定位数据集
  • 从堆叠到双线性:手把手图解注意力机制的‘进化史’与PyTorch实现对比
  • Python异步编程模式:从同步到异步的演进
  • AUTO TECH China 2026广州汽车零部件展:从整机集成迈向核心部件的产业跃升
  • 镜像视界(浙江)科技有限公司|空间智能·视频孪生·无感定位·跨镜跟踪
  • 别再死记硬背了!用Python的Matplotlib亲手画一遍sinx、cosx、tanx等函数图像,理解更深刻
  • 《CVPR2025-DEIM创新改进项目实战:从原理到部署的深度学习优化全攻略》005、DEIM模型架构总览——编码器-解码器与动态门控设计
  • DFT笔记57
  • 分支管理(一):创建、切换与合并,体验“平行宇宙”
  • 告别理论!5分钟用PyWavelets搞定二维离散小波变换(2D-DWT)的Python代码实战
  • 你的电机为什么抖?排查STM32F4 PWM驱动TB6612的5个常见硬件坑(附示波器实测)
  • 告别GDB依赖:在NEMU里打造专属调试器,我是如何搞定单步执行与内存扫描的
  • Rust内存安全:所有权、借用与生命周期深度解析
  • SWAT模型高阶十七项案例分析实践技术
  • 别再用理想模型了!用TINA-TI仿真μA741驱动容性负载,实测振铃现象与消除方案
  • AnyVisLoc:专为低空多视角无人机定位打造的全球首个统一评测基准
  • 如何监控 RabbitMQ 队列长度实现自动告警
  • 别再只会用关键词了!这10个Google搜索命令,让你找资料效率翻倍(附实战案例)
  • 插件:Custom Attachment Location 图片自定义
  • 不用真飞机!用BetaFlight遥控器玩转PX4无人机仿真:QGC配置与手动飞行入门
  • 别再死记硬背物联网四层架构了!用LoRa和ESP32手把手搭个智能花盆,实战理解每一层
  • ARM SPE统计性能分析扩展与缓冲区管理机制详解