当前位置: 首页 > news >正文

从NLP跨界CV:手把手教你用PyTorch复现Vision Transformer (ViT) 图像分类

从NLP跨界CV:手把手教你用PyTorch复现Vision Transformer (ViT) 图像分类

当Transformer在自然语言处理领域大放异彩时,计算机视觉研究者们开始思考:这种基于自注意力机制的架构能否同样颠覆图像识别领域?2020年,Vision Transformer (ViT) 的出现给出了肯定答案。本文将带你从零开始,用PyTorch实现这一开创性模型,体验如何将图像转化为"视觉词汇"的奇妙过程。

1. ViT核心原理与设计思路

传统卷积神经网络(CNN)通过局部感受野逐层提取特征,而ViT则采用全局视角处理图像——它将输入图片分割为16x16的"视觉词汇块"(patches),每个块经过线性投影后成为Transformer可处理的序列元素。这种设计带来了三大关键创新:

  1. 图像序列化:将2D图像转换为1D令牌序列
  2. 位置编码:通过可学习的位置嵌入保留空间信息
  3. 纯Transformer架构:完全摒弃卷积操作

注意:ViT在中小型数据集上可能不如CNN表现优异,但当训练数据超过1亿张图片时,其性能开始显著超越传统方法。

下表对比了ViT与典型CNN的核心差异:

特性ViTCNN
特征提取方式全局自注意力局部卷积核
空间信息处理显式位置编码隐式感受野累积
数据依赖性需要大量训练数据中等规模数据即可
计算复杂度O(n²)O(n)

2. 环境准备与数据预处理

2.1 安装必要依赖

确保你的Python环境包含以下核心库:

pip install torch torchvision pytorch-lightning einops

2.2 CIFAR-10数据集处理

我们将使用CIFAR-10作为演示数据集。虽然原始ViT论文使用更大规模的ImageNet,但CIFAR-10更适合快速验证:

from torchvision import datasets, transforms # 定义数据增强策略 train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载数据集 train_data = datasets.CIFAR10('data', train=True, download=True, transform=train_transform) test_data = datasets.CIFAR10('data', train=False, transform=train_transform)

3. ViT模型实现详解

3.1 图像分块与线性嵌入

ViT的第一步是将图像分割为固定大小的块并线性投影到特征空间:

import torch import torch.nn as nn from einops import rearrange class PatchEmbedding(nn.Module): def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=64): super().__init__() self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x) # [B, C, H, W] -> [B, D, H/P, W/P] x = rearrange(x, 'b d h w -> b (h w) d') return x

3.2 位置编码与分类令牌

Transformer需要位置信息来理解图像的空间结构:

class ViTEncoder(nn.Module): def __init__(self, num_patches, embed_dim, num_heads, num_layers): super().__init__() self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim)) self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer(embed_dim, num_heads), num_layers ) def forward(self, x): cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) x += self.pos_embed return self.transformer(x)

4. 完整模型组装与训练

4.1 构建端到端ViT模型

整合所有组件形成完整架构:

class VisionTransformer(nn.Module): def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=64, num_heads=4, num_layers=4, num_classes=10): super().__init__() self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim) num_patches = (img_size // patch_size) ** 2 self.encoder = ViTEncoder(num_patches, embed_dim, num_heads, num_layers) self.head = nn.Linear(embed_dim, num_classes) def forward(self, x): x = self.patch_embed(x) x = self.encoder(x) return self.head(x[:, 0]) # 使用分类令牌输出

4.2 训练策略与超参数设置

使用PyTorch Lightning简化训练流程:

import pytorch_lightning as pl from torch.utils.data import DataLoader class ViTLightning(pl.LightningModule): def __init__(self, lr=1e-3): super().__init__() self.model = VisionTransformer() self.lr = lr self.criterion = nn.CrossEntropyLoss() def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): x, y = batch preds = self(x) loss = self.criterion(preds, y) self.log('train_loss', loss) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr) # 初始化训练器 trainer = pl.Trainer(max_epochs=50, gpus=1 if torch.cuda.is_available() else 0) model = ViTLightning() # 数据加载器 train_loader = DataLoader(train_data, batch_size=64, shuffle=True) test_loader = DataLoader(test_data, batch_size=64) # 开始训练 trainer.fit(model, train_loader)

5. 模型优化与调参技巧

5.1 学习率调度策略

ViT训练对学习率非常敏感,推荐使用warmup策略:

def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=self.lr, total_steps=self.trainer.estimated_stepping_batches ) return [optimizer], [scheduler]

5.2 混合精度训练加速

利用NVIDIA GPU的Tensor Core加速训练:

trainer = pl.Trainer( max_epochs=50, precision=16, accelerator='gpu' if torch.cuda.is_available() else 'cpu' )

5.3 关键超参数经验值

基于CIFAR-10的实验验证,以下配置表现良好:

参数推荐值说明
patch_size4平衡计算量与局部信息保留
embed_dim64-128特征维度
num_heads4-8注意力头数
num_layers6-12Transformer层数
batch_size64-128根据GPU内存调整

6. 模型评估与结果分析

6.1 测试集性能评估

def test_step(self, batch, batch_idx): x, y = batch preds = self(x) loss = self.criterion(preds, y) acc = (preds.argmax(1) == y).float().mean() self.log('test_loss', loss) self.log('test_acc', acc) return {'loss': loss, 'acc': acc}

6.2 可视化注意力机制

理解模型如何关注图像不同区域:

import matplotlib.pyplot as plt def visualize_attention(model, img): model.eval() with torch.no_grad(): patches = model.patch_embed(img.unsqueeze(0)) attns = model.encoder.transformer.layers[0].self_attn( patches, patches, patches )[1] plt.imshow(attns[0, 0, 1:].reshape(8, 8).cpu()) plt.colorbar() plt.show()

在CIFAR-10上训练约50个epoch后,预期可以达到75-80%的测试准确率。虽然这低于原始论文在更大数据集上的结果,但足以验证ViT的基本原理。

http://www.jsqmd.com/news/964460/

相关文章:

  • 【题解】 ABC 461
  • 企业微信SCRM场景化盘点:采购负责人选型参考指南 - 资讯速览
  • 【CSDN AI引流卡片合规指南】:20年数字营销老兵亲测——微信/公众号链接能否放?3大红线+2份平台最新条款原文解读
  • 3个真实困境如何被一个脚本改写?揭秘网盘直链下载助手的底层逻辑
  • Agent-S3:首个超越人类性能的智能体框架技术解析与架构设计
  • Python 爬虫实战:分页循环爬取科普资讯基础实现方案
  • 5分钟搞定!Windows系统激活工具的终极使用指南
  • 基于 Harmony 6.0 应用的跑步配速教练应用首页实现
  • Windows/Mac通用教程:用旧版PS CS6和Acrobat Pro DC 2015,搞定超长网页截图打印(避坑指南)
  • 2026年 南通短视频运营/拍摄/获客/GEO推荐榜单:实战派团队与爆款创意口碑之选 - 企业推荐官【官方】
  • 2026年 南通短视频运营/拍摄/获客/GEO服务商推荐榜:实战派团队与创意爆款内容深度解析 - 企业推荐官【官方】
  • 别再死记硬背了!用一张外卖订单图,5分钟搞懂Hadoop MapReduce核心流程
  • 2026年徐州黄金回收行业发展指南:市场现状、交易流程与靠谱服务商盘点 - 寻茫精选
  • 2026年徐州黄金回收全指南:交易规则、避坑要点与靠谱服务方盘点 - 寻茫精选
  • 国产化替代实战:在统信UOS服务器上部署达梦DM8数据库的完整配置清单
  • 如何快速突破网盘限速:LinkSwift直链下载助手完整教程
  • 5分钟搞懂Guesslang:如何让AI一眼识别54种编程语言?
  • 揭秘藏品回收真相!北京丰宝斋告诉你,正规机构该有的样子 - 深鉴新闻
  • STM32F207多功能评估板设计:从离线编程到脚本化测试的硬件整合实践
  • Notepad2-mod深度解析:基于Scintilla引擎的轻量级编辑器架构剖析
  • 苏州拍婚纱照怎么选、多少钱、注意什么?一篇答疑 - eee888
  • 2026年网架厂家实力解析:徐州网架/煤棚网架/体育馆网架/大跨度网架/焊接球网架/螺栓球网架专业供应商深度解读 - 品牌企业推荐师(官方)
  • CE认证电缆厂家常见问题解答(2026最新专家版) - 资讯速览
  • KiTTY:解决Windows远程连接痛点的SSH客户端
  • 【2026必藏】6款智能降AI率网站大曝光,一键让AIGC率断崖式下跌! - 降AI小能手
  • 啤酒机气表常见问题解答(2026最新专家版) - 资讯速览
  • 深入AXI4-Lite总线:从AXI GPIO的寄存器读写,理解Zynq PL-PS数据交互的底层逻辑
  • HC-SR04超声波传感器Arduino一键测距库(带单位切换与稳定输出示例)
  • 万国手表全国售后服务网络升级公告 - 资讯速览
  • 2026年天津仓储货架供应厂家:重型/轻型/阁楼/智能货架,高效仓储与承重耐用之选 - 品牌企业推荐师(官方)