当前位置: 首页 > news >正文

别再只用MNIST了!Permuted/Split MNIST数据集实战:用PyTorch搭建你的第一个连续学习评估环境

突破传统MNIST:用PyTorch构建连续学习实战环境的完整指南

当大多数机器学习教程还在使用标准MNIST数据集时,前沿研究早已转向更具挑战性的变体。Permuted MNIST和Split MNIST不仅是学术论文中的常客,更是检验连续学习算法性能的黄金标准。本文将带你从零开始,用PyTorch搭建完整的连续学习评估环境,通过代码实践理解这两种经典数据集的精髓。

1. 为什么需要超越标准MNIST?

传统MNIST数据集作为机器学习入门教材已有二十余年历史,其简单性既是优点也是局限。784个像素点的固定排列方式让模型很快就能达到接近人类水平的准确率,但这恰恰掩盖了现实世界中的核心挑战——数据分布的动态变化。

连续学习研究需要能模拟以下场景的数据集:

  • 任务序列化:模型需要按顺序学习多个相关但不完全相同的任务
  • 灾难性遗忘:新任务学习时旧任务性能的下降程度需要量化
  • 知识迁移:先前学到的知识如何帮助后续任务学习

提示:Permuted MNIST通过像素重排改变输入分布,Split MNIST通过类别划分创建任务序列,两者分别对应Domain-IL和Class-IL场景

下表对比了三种MNIST变体的核心差异:

数据集类型变化维度适用场景挑战重点
标准MNIST无变化基础分类单一任务性能
Permuted MNIST像素排列Domain-IL分布变化适应
Split MNIST类别划分Class-IL类别增量学习

2. 环境准备与基础配置

开始前确保已安装最新版PyTorch和标准科学计算库:

import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms import numpy as np import matplotlib.pyplot as plt

定义基础多层感知机(MLP)模型,这是连续学习研究的常用架构:

class BaseMLP(nn.Module): def __init__(self, input_size=784, hidden_size=400, output_size=10): super(BaseMLP, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, output_size) self.relu = nn.ReLU() def forward(self, x): x = x.view(x.size(0), -1) # 展平输入 x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) return self.fc3(x)

3. Permuted MNIST实战实现

Permuted MNIST的核心是为每个任务生成唯一的像素排列顺序。以下是关键实现步骤:

  1. 生成排列矩阵:为每个任务创建随机但固定的像素排列
  2. 数据转换器:应用排列并标准化图像
  3. 任务序列构建:创建多个不同排列的数据加载器
def generate_permutation(): """生成随机像素排列顺序""" return torch.randperm(784) class PermuteTransform: """应用像素排列的数据转换器""" def __init__(self, permutation): self.permutation = permutation def __call__(self, x): return x.view(-1)[self.permutation].view(1, 28, 28) # 创建5个不同排列的任务 num_tasks = 5 permutations = [generate_permutation() for _ in range(num_tasks)] task_loaders = [] for perm in permutations: transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), PermuteTransform(perm) ]) train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True, transform=transform), batch_size=64, shuffle=True) task_loaders.append(train_loader)

可视化不同排列下的图像样本:

def show_permuted_samples(permutations, num_samples=5): """展示不同排列下的MNIST样本""" fig, axes = plt.subplots(len(permutations), num_samples, figsize=(15, 10)) for i, perm in enumerate(permutations): transform = PermuteTransform(perm) for j in range(num_samples): img, _ = datasets.MNIST('../data', train=True)[j] axes[i,j].imshow(transform(img).squeeze(), cmap='gray') axes[i,j].axis('off') plt.show() show_permuted_samples(permutations[:3]) # 展示前3种排列

4. Split MNIST的精细实现

Split MNIST将10个数字类别划分为多个二元分类任务,典型划分方式如下:

  • Task 1: 识别0和1
  • Task 2: 识别2和3
  • Task 3: 识别4和5
  • Task 4: 识别6和7
  • Task 5: 识别8和9

实现关键点在于数据过滤和标签重映射:

def create_split_mnist_loaders(): """创建Split MNIST任务序列""" base_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) # 定义任务划分 (数字对列表) task_pairs = [(0,1), (2,3), (4,5), (6,7), (8,9)] task_loaders = [] for pair in task_pairs: # 过滤只包含当前数字对的样本 def filter_func(data, target): mask = (target == pair[0]) | (target == pair[1]) return data[mask], target[mask] # 重映射标签为0/1 def remap_labels(target): return (target == pair[1]).long() # 自定义数据集 class SplitMNIST(torch.utils.data.Dataset): def __init__(self, train=True): self.mnist = datasets.MNIST('../data', train=train, download=True) self.data, self.targets = filter_func(self.mnist.data, self.mnist.targets) self.targets = remap_labels(self.targets) self.transform = base_transform def __len__(self): return len(self.data) def __getitem__(self, idx): img, target = self.data[idx], self.targets[idx] img = self.transform(img.numpy()) return img, target train_loader = torch.utils.data.DataLoader( SplitMNIST(train=True), batch_size=64, shuffle=True) task_loaders.append(train_loader) return task_loaders split_loaders = create_split_mnist_loaders()

5. 连续学习评估框架

完整的连续学习评估需要跟踪以下指标:

  • 当前任务准确率
  • 旧任务遗忘程度
  • 整体平均准确率

实现评估流程的核心代码:

def evaluate(model, task_id, test_loaders, device='cpu'): """评估模型在所有已学习任务上的表现""" model.eval() accuracies = [] with torch.no_grad(): for t in range(task_id + 1): correct = 0 total = 0 for images, labels in test_loaders[t]: images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracies.append(100 * correct / total) return accuracies def continual_learning_training(model, task_loaders, num_epochs=5): """连续学习训练流程""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() all_accuracies = [] for task_id, train_loader in enumerate(task_loaders): print(f"\n=== Training on Task {task_id + 1} ===") for epoch in range(num_epochs): model.train() running_loss = 0.0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}") # 评估当前模型在所有已学习任务上的表现 current_acc = evaluate(model, task_id, task_loaders, device) all_accuracies.append(current_acc) print(f"Task Accuracies after Task {task_id+1}: {current_acc}") return all_accuracies

6. 高级技巧与优化方向

基础实现后,可以考虑以下增强方案:

对抗遗忘技术

  • 弹性权重固化(EWC):添加正则项保护重要参数
  • 记忆回放:保留少量旧任务样本进行联合训练
  • 动态架构:为每个任务分配专用模型组件
# EWC实现示例 def ewc_loss(model, fisher_matrix, previous_params, lambda_ewc): loss = 0 for name, param in model.named_parameters(): if name in fisher_matrix: loss += (fisher_matrix[name] * (param - previous_params[name])**2).sum() return lambda_ewc * loss # 在训练循环中添加EWC损失 total_loss = criterion(outputs, labels) + ewc_loss(model, fisher, prev_params, lambda_ewc=1000)

评估指标可视化

def plot_learning_curve(accuracies): """绘制连续学习曲线""" plt.figure(figsize=(10, 6)) for task in range(len(accuracies)): x = range(task + 1, len(accuracies) + 1) y = [acc[task] for acc in accuracies[task:]] plt.plot(x, y, marker='o', label=f'Task {task+1}') plt.xlabel('Task Number') plt.ylabel('Accuracy (%)') plt.title('Continual Learning Performance') plt.legend() plt.grid() plt.show()

实际项目中,我发现Permuted MNIST对初始化种子非常敏感,不同排列顺序可能导致性能波动达5-8%。解决方法是固定随机种子或进行多次实验取平均。Split MNIST则面临类别不平衡问题,某些数字对的样本量可能相差20%,需要在损失函数中添加类别权重。

http://www.jsqmd.com/news/764416/

相关文章:

  • 2025-2026美国移民机构深度测评:十大靠谱移民公司优势对比 - 品牌排行榜
  • PerfectDou:用完美信息蒸馏技术打造最强斗地主AI
  • EPPlus高级数据操作:使用LINQ和Lambda表达式处理Excel数据
  • 明日方舟智能基建管理工具:Arknights-Mower 完整使用指南
  • 告别重复造轮子:用快马AI为OpenClaw101项目生成高效开发工具集
  • Wan2.2-I2V-A14B WebUI汉化与定制:修改前端界面支持中文prompt友好输入
  • 从实验室到现场:高压设备绝缘距离怎么定?手把手教你理解“伏秒特性”与绝缘配合
  • MCP 2026边缘性能瓶颈诊断与突破(2024Q3最新FPGA+ARM异构部署实战手册)
  • PhoneGap Developer App部署与发布指南:Android、iOS、Windows Phone
  • 蓝桥杯嵌入式备赛:手把手教你搞定IIC驱动AT24C02和MCP4017(附完整代码)
  • 文案生成:从零开始的实用方法指南
  • 感定室外,孪生实时算\n \n纯视觉破局,孪生可测可控
  • 3个常见工作难题:如何用taskt零代码实现自动化突破?
  • Python 爬虫反爬突破:前端加密算法本地复现与调用
  • 昆山祥泽瑞:吴中专业的角钢批发有哪些 - LYL仔仔
  • 上海恩依餐饮:上海市家庭宴请推荐哪几家 - LYL仔仔
  • 量子催眠实施标准:软件测试从业者的意识探索指南
  • PC与智能手机出货量走势分化,AI浪潮下迷你主机线下遇冷线上待兴?
  • ComfyUI-WanVideoWrapper:AI视频生成的终极解决方案 - 从文本到视频的魔法变身
  • 2026年昆明代理记账服务深度指南:今非财税官方联系方式与行业横评 - 年度推荐企业名录
  • 【实战派×学院派】103|团队氛围消极,干活像交差,缺乏动力?
  • 还在手写policy.json?MCP 2026 2026.3版本已强制启用策略生命周期自动巡检,你的配置还能撑过下个季度吗?
  • 六西格玛成绩有效期多久? - 众智商学院官方
  • PostgreSQL 技术日报 (5月6日)|向量扩展新版本发布,内核并发机制迭代
  • M9A:重返未来1999终极自动化助手完整指南,三步实现游戏日常全托管
  • OBS高级计时器:为直播和视频制作提供精准时间管理
  • STM32 I2C LCD 1602驱动终极指南:3步实现嵌入式显示控制
  • 单图生成3D场景:NeRF技术革新与应用实践
  • 2026年昆明代理记账服务全生命周期深度横评与选购指南 - 年度推荐企业名录
  • 2026年昆明代理记账服务全景指南:五大品牌深度横评与企业选购宝典 - 年度推荐企业名录