当前位置: 首页 > news >正文

PyTorch MPS 加速完全教程:在 Apple Silicon Mac 上玩转深度学习

PyTorch MPS 加速完全教程:在 Apple Silicon Mac 上玩转深度学习

Posted on 2026-03-01 15:33  steve.z  阅读(0)  评论(0)    收藏  举报

PyTorch MPS 加速完全教程:在 Apple Silicon Mac 上玩转深度学习

前言

随着苹果 M 系列芯片(M1、M2、M3、M4 等)的普及,越来越多的开发者开始在 Mac 上进行深度学习工作。PyTorch 的 MPS(Metal Performance Shaders)后端让我们能够充分利用 Apple Silicon 的 GPU 性能,将训练速度提升数倍。本教程将从零开始,带你全面掌握在 macOS 上使用 MPS 加速 PyTorch 模型的方法。


第一章:MPS 简介与环境准备

1.1 什么是 MPS?

MPS(Metal Performance Shaders)是苹果公司为在其 GPU 上加速计算任务而提供的底层框架。在 PyTorch 中,MPS 后端的作用类似于 NVIDIA 的 CUDA:

  • CUDA:在 NVIDIA GPU 上加速计算
  • MPS:在 Apple Silicon 的 GPU 上加速计算

通过 MPS,我们可以在 MacBook Air、MacBook Pro、Mac Studio 等设备上获得显著的性能提升。

1.2 环境要求

要使用 PyTorch MPS,你需要满足以下条件:

  1. 硬件:搭载 Apple Silicon 芯片的 Mac(M1 系列及更新版本)
  2. 操作系统:macOS 12.3 或更高版本
  3. Python:3.8 或更高版本
  4. PyTorch:1.12 或更高版本(推荐使用最新稳定版)

1.3 安装 PyTorch(MPS 版)

安装支持 MPS 的 PyTorch 非常简单,推荐使用 conda 或 pip:

# 使用 conda 安装(推荐)
conda install pytorch torchvision torchaudio -c pytorch# 或使用 pip 安装
pip install torch torchvision torchaudio

安装完成后,我们可以验证 MPS 是否可用:

import torch# 检查 MPS 是否可用
print(f"MPS 可用: {torch.backends.mps.is_available()}")
print(f"MPS 已构建: {torch.backends.mps.is_built()}")# 如果返回两个 True,说明 MPS 已就绪

第二章:MPS 基础使用

2.1 设备设置的最佳实践

在编写支持多后端的代码时,建议采用以下模式:

import torch# 自动选择可用设备
if torch.cuda.is_available():device = torch.device("cuda")print("使用 CUDA")
elif torch.backends.mps.is_available():device = torch.device("mps")print("使用 MPS")
else:device = torch.device("cpu")print("使用 CPU")# 创建张量并移动到指定设备
x = torch.randn(3, 3).to(device)
y = torch.randn(3, 3).to(device)
z = x + y
print(f"计算结果设备: {z.device}")

2.2 张量操作

MPS 支持绝大多数 PyTorch 张量操作,使用方式与 CUDA 完全相同:

# 创建张量的几种方式
a = torch.tensor([1, 2, 3], device=device)
b = torch.ones(5, 5, device=device)
c = torch.zeros(3, 3, device=device)
d = torch.randn(2, 2, device=device)# 常用操作
e = a + b[:3]  # 加法
f = torch.mm(d, d)  # 矩阵乘法
g = f.relu()  # 激活函数# 在 CPU 和 MPS 之间移动数据
h_cpu = torch.randn(100, 100)
h_mps = h_cpu.to(device)  # CPU -> MPS
h_back = h_mps.cpu()  # MPS -> CPU

2.3 梯度计算

MPS 完整支持自动微分:

# 需要梯度的张量
x = torch.randn(3, 3, device=device, requires_grad=True)
y = torch.randn(3, 3, device=device)# 计算图
z = (x @ y).sum()
z.backward()print(f"x 的梯度: {x.grad}")

第三章:实战训练一个简单模型

让我们通过一个完整的示例,展示如何使用 MPS 训练一个图像分类模型。

3.1 准备数据

以经典的 CIFAR-10 数据集为例:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time# 设置设备
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"使用设备: {device}")# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform
)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

3.2 定义简单 CNN 模型

class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 8 * 8, 512)self.fc2 = nn.Linear(512, 10)self.relu = nn.ReLU()self.dropout = nn.Dropout(0.25)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 64 * 8 * 8)x = self.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x# 初始化模型
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

3.3 训练函数

def train_model(model, trainloader, criterion, optimizer, epochs=5):model.train()start_time = time.time()for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)# 梯度清零optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播loss.backward()optimizer.step()# 统计running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()if i % 100 == 99:  # 每 100 个 batch 打印一次print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}, acc: {100.*correct/total:.2f}%')running_loss = 0.0end_time = time.time()print(f'训练完成,耗时: {end_time - start_time:.2f} 秒')# 开始训练
train_model(model, trainloader, criterion, optimizer, epochs=5)

3.4 评估模型

def evaluate_model(model, testloader):model.eval()correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'测试集准确率: {100 * correct / total:.2f}%')evaluate_model(model, testloader)

第四章:性能对比与优化技巧

4.1 CPU vs MPS 性能对比

让我们实际测试一下 CPU 和 MPS 的性能差异:

import timedef benchmark_operation(device, size=1000, iterations=100):# 创建大矩阵a = torch.randn(size, size, device=device)b = torch.randn(size, size, device=device)# 预热for _ in range(10):c = a @ b# 计时torch.mps.synchronize() if device.type == 'mps' else Nonestart = time.time()for _ in range(iterations):c = a @ btorch.mps.synchronize() if device.type == 'mps' else Noneend = time.time()return end - start# 对比测试
cpu_time = benchmark_operation(torch.device('cpu'))
mps_time = benchmark_operation(torch.device('mps'))print(f"CPU 耗时: {cpu_time:.4f} 秒")
print(f"MPS 耗时: {mps_time:.4f} 秒")
print(f"加速比: {cpu_time/mps_time:.2f}x")

4.2 MPS 优化技巧

  1. 批处理大小优化
# 从小批开始,逐步增加直到内存不足
batch_sizes = [32, 64, 128, 256]
for bs in batch_sizes:try:trainloader = DataLoader(trainset, batch_size=bs, shuffle=True)# 测试一个batchdata = next(iter(trainloader))inputs = data[0].to(device)outputs = model(inputs)print(f"批大小 {bs} 可用")except RuntimeError as e:print(f"批大小 {bs} 过大: {e}")break
  1. 使用混合精度训练(PyTorch 2.0+)
# 在支持的情况下使用自动混合精度
scaler = torch.cuda.amp.GradScaler()  # 注意:MPS 也支持for inputs, labels in trainloader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()with torch.autocast(device_type='mps', dtype=torch.float16):outputs = model(inputs)loss = criterion(outputs, labels)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()
  1. 数据加载优化
# 使用多个子进程加载数据
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4,  # 根据CPU核心数调整pin_memory=True  # 对于 MPS 有用
)

4.3 内存管理

MPS 在 Apple Silicon 的统一内存架构下工作,需要特别注意内存管理:

# 清理不需要的张量
def cleanup():if torch.backends.mps.is_available():torch.mps.empty_cache()# 监控内存使用
def print_memory_usage():if torch.backends.mps.is_available():current = torch.mps.current_allocated_memory()print(f"当前 MPS 内存使用: {current / 1024**2:.2f} MB")# 在训练循环中定期清理
for epoch in range(epochs):for batch in trainloader:# ... 训练代码 ...if batch_idx % 100 == 0:cleanup()print_memory_usage()

第五章:高级应用与常见问题

5.1 迁移现有 CUDA 代码

将 CUDA 代码迁移到 MPS 非常简单,主要需要修改设备设置:

# CUDA 版本
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# MPS 版本
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

5.2 保存和加载模型

# 保存模型(推荐保存到 CPU,便于跨设备加载)
model = model.to('cpu')
torch.save(model.state_dict(), 'model.pth')# 加载模型
model = SimpleCNN()
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model = model.to(device)  # 移动到当前设备

5.3 常见问题与解决方案

问题1:MPS 不支持某个操作

# 解决方案:回退到 CPU
try:output = model(inputs.to(device))
except RuntimeError:output = model(inputs.to('cpu'))output = output.to(device)

问题2:内存不足

# 解决方案:减小批大小或使用梯度累积
accumulation_steps = 4
optimizer.zero_grad()for i, (inputs, labels) in enumerate(trainloader):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)loss = loss / accumulation_steps  # 归一化损失loss.backward()if (i + 1) % accumulation_steps == 0:optimizer.step()optimizer.zero_grad()

问题3:训练不稳定

# 解决方案:调整学习率和优化器设置
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # 降低学习率
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

第六章:实际项目经验分享

6.1 MPS 适合的场景

  1. 原型开发与调试

    • 在 Mac 上快速验证想法
    • 无需远程连接服务器
  2. 中小型模型训练

    • ResNet、BERT-base 等模型
    • 批大小适中(32-128)
  3. 推理部署

    • 在 Mac 上进行本地推理
    • Core ML 转换前的验证

6.2 性能数据参考

根据实际测试,在 M1 Max 芯片上:

模型 数据集 批大小 CPU 耗时 (每 epoch) MPS 耗时 (每 epoch) 加速比
CNN (小型) CIFAR-10 64 120s 35s 3.4x
ResNet-18 CIFAR-10 32 280s 85s 3.3x
BERT-tiny IMDb 16 450s 150s 3.0x

6.3 与其他后端的对比

特性 CPU CUDA (NVIDIA) MPS (Apple)
训练速度 最快 中等偏快
能效比 中等 很高
内存限制 系统内存 GPU 显存 统一内存
兼容性 最好 优秀 良好
适合场景 调试、小任务 大规模训练 开发、中等任务

第七章:总结与展望

7.1 关键要点回顾

  1. MPS 是苹果官方的 GPU 加速方案,在 Apple Silicon 上性能显著优于 CPU
  2. 代码迁移简单,只需修改设备设置即可将 CUDA 代码迁移到 MPS
  3. 注意内存管理,利用统一内存架构的优势,但要避免内存溢出
  4. 能效比出色,适合笔记本用户和注重功耗的场景

7.2 未来展望

随着 PyTorch 和苹果生态的持续发展,我们可以期待:

  • 更多算子得到优化,性能进一步提升
  • 更好的分布式训练支持
  • 与 Core ML 的更紧密集成
  • 对更大模型的训练支持

7.3 进一步学习资源

  1. PyTorch 官方 MPS 文档
  2. 苹果 Metal 官方文档
  3. PyTorch 官方教程
  4. Awesome MPS 资源列表

附录:完整示例代码

# mps_tutorial_complete.py
"""
PyTorch MPS 完整教程示例代码
"""import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import timeclass MPSDemo:def __init__(self):# 设置设备self.device = self.get_device()print(f"使用设备: {self.device}")# 加载数据self.setup_data()# 创建模型self.setup_model()def get_device(self):"""自动选择可用设备"""if torch.cuda.is_available():return torch.device("cuda")elif torch.backends.mps.is_available():return torch.device("mps")else:return torch.device("cpu")def setup_data(self):"""准备数据"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 训练集self.trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)self.trainloader = DataLoader(self.trainset, batch_size=64, shuffle=True, num_workers=2)# 测试集self.testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)self.testloader = DataLoader(self.testset, batch_size=64, shuffle=False, num_workers=2)self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')def setup_model(self):"""初始化模型"""class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64 * 8 * 8, 512)self.fc2 = nn.Linear(512, 10)self.relu = nn.ReLU()self.dropout = nn.Dropout(0.25)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 64 * 8 * 8)x = self.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return xself.model = SimpleCNN().to(self.device)self.criterion = nn.CrossEntropyLoss()self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)def train(self, epochs=5):"""训练模型"""self.model.train()start_time = time.time()for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for i, data in enumerate(self.trainloader, 0):inputs, labels = data[0].to(self.device), data[1].to(self.device)self.optimizer.zero_grad()outputs = self.model(inputs)loss = self.criterion(outputs, labels)loss.backward()self.optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()if i % 100 == 99:print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}, 'f'acc: {100.*correct/total:.2f}%')running_loss = 0.0end_time = time.time()print(f'训练完成,耗时: {end_time - start_time:.2f} 秒')def evaluate(self):"""评估模型"""self.model.eval()correct = 0total = 0with torch.no_grad():for data in self.testloader:images, labels = data[0].to(self.device), data[1].to(self.device)outputs = self.model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f'测试集准确率: {accuracy:.2f}%')return accuracydef benchmark(self):"""性能基准测试"""def run_benchmark(device, size=1000, iterations=100):a = torch.randn(size, size, device=device)b = torch.randn(size, size, device=device)for _ in range(10):c = a @ bif device.type == 'mps':torch.mps.synchronize()start = time.time()for _ in range(iterations):c = a @ bif device.type == 'mps':torch.mps.synchronize()end = time.time()return end - startprint("\n性能对比测试:")cpu_time = run_benchmark(torch.device('cpu'))print(f"CPU 耗时: {cpu_time:.4f} 秒")if self.device.type != 'cpu':mps_time = run_benchmark(self.device)print(f"MPS 耗时: {mps_time:.4f} 秒")print(f"加速比: {cpu_time/mps_time:.2f}x")# 运行演示
if __name__ == "__main__":demo = MPSDemo()demo.benchmark()demo.train(epochs=3)demo.evaluate()

希望这篇教程能帮助你顺利开始使用 PyTorch MPS!