当前位置: 首页 > news >正文

Day 38 - Dataset 和 DataLoader

在深度学习任务中,数据处理是至关重要的一环。面对大规模数据集,显存往往无法一次性存储所有数据,因此需要采用分批训练(Batch Training)的策略。PyTorch 提供了两个核心工具类来解决数据加载和预处理的问题:DatasetDataLoader

本文将深入探讨这两个类的原理、用法以及它们之间的关系,并以经典的 MNIST 手写数字数据集为例进行演示。

一、 PyTorch 数据处理核心架构

在 PyTorch 中,数据处理流程被解耦为两个独立的部分:

  1. Dataset (数据集):负责定义“数据是什么”,即如何获取单个样本及其对应的标签,以及如何进行预处理。
  2. DataLoader (数据加载器):负责定义“如何加载数据”,即如何将 Dataset 中的样本组装成批次(Batch),并提供多线程加载、随机打乱等功能。

形象比喻

  • Dataset就像是厨师,他的工作是负责把每一个菜品(样本)切好、洗好、调好味(预处理)。
  • DataLoader就像是服务员,他的工作是把厨师做好的菜品,按照订单的要求(Batch Size),打包好端给客人(模型)。

二、 Dataset 类详解

torch.utils.data.Dataset是一个抽象基类,所有自定义的数据集都必须继承它,并实现其核心接口。

1. 核心魔术方法

PyTorch 要求 Dataset 子类必须实现以下两个魔术方法(Magic Methods):

  • __len__(self)
    • 作用:返回数据集的样本总数。
    • 调用方式:当使用len(dataset)时自动调用。
    • 意义:DataLoader 需要知道数据集的大小,以便计算一个 Epoch 需要多少个 Batch。
  • __getitem__(self, idx)
    • 作用:根据索引idx获取单个样本的数据和标签。
    • 调用方式:当使用dataset[idx]时自动调用。
    • 意义:这是数据读取和预处理发生的地方。

2. Python 魔术方法原理解析

为了更好地理解__len____getitem__,我们来看一个简单的 Python 自定义类示例:

class MyList: def __init__(self): self.data = [10, 20, 30, 40, 50] # 实现索引访问功能 def __getitem__(self, idx): return self.data[idx] # 实现长度获取功能 def __len__(self): return len(self.data) # 实例化对象 my_list_obj = MyList() # 1. 测试 __getitem__ # 对象可以直接使用 [] 索引访问,像内置列表一样 print(f"索引为2的元素: {my_list_obj[2]}") # 输出: 30 # 2. 测试 __len__ # 对象可以直接使用 len() 函数 print(f"列表长度: {len(my_list_obj)}") # 输出: 5

3. 自定义 Dataset 示例

基于上述原理,一个典型的自定义 Dataset 结构如下:

from torch.utils.data import Dataset class MNIST(Dataset): def __init__(self, root, train=True, transform=None): """ 初始化:加载文件路径、标签文件等 """ # 假设 fetch_mnist_data 是一个自定义函数,用于读取数据 self.data, self.targets = fetch_mnist_data(root, train) self.transform = transform # 预处理操作流水线 def __len__(self): """ 返回数据集大小 """ return len(self.data) def __getitem__(self, idx): """ 获取指定索引 idx 的样本 """ # 1. 根据索引获取原始数据和标签 img, target = self.data[idx], self.targets[idx] # 2. 应用预处理(如转 Tensor、归一化等) if self.transform is not None: img = self.transform(img) return img, target

三、 实战:使用 torchvision 加载 MNIST

torchvision是 PyTorch 官方的计算机视觉库,其中torchvision.datasets模块内置了许多常用数据集(如 MNIST, CIFAR10, ImageNet 等),它们都已经实现了 Dataset 的接口。

1. 数据预处理 (Transforms)

在加载图像数据时,通常需要进行一系列预处理,如转为张量(Tensor)、归一化(Normalize)等。

from torchvision import transforms # 定义预处理流水线 transform = transforms.Compose([ transforms.ToTensor(), # 将图像转换为 PyTorch 张量,并将像素值归一化到 [0, 1] transforms.Normalize((0.1307,), (0.3081,)) # 标准化:(x - mean) / std。参数为 MNIST 数据集的全局均值和标准差 ])

2. 加载数据集

from torchvision import datasets # 加载训练集 train_dataset = datasets.MNIST( root='./data', # 数据存储路径 train=True, # True 表示加载训练集 download=True, # 如果路径下不存在数据,是否自动下载 transform=transform # 应用上面定义的预处理 ) # 加载测试集 test_dataset = datasets.MNIST( root='./data', train=False, # False 表示加载测试集 transform=transform )

注意:在 PyTorch 的设计哲学中,数据预处理通常是在加载阶段(即__getitem__被调用时)动态进行的,而不是先处理好再保存。这样做可以节省磁盘空间,并支持动态的数据增强。

3. 查看单个样本

由于train_dataset本质上是一个 Dataset 子类,我们可以直接通过索引访问:

import matplotlib.pyplot as plt import torch # 随机获取一个索引 sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 获取样本(自动调用 __getitem__) image, label = train_dataset[sample_idx] # 可视化(需要反归一化以便人眼观察) def imshow(img): img = img * 0.3081 + 0.1307 # 反标准化 npimg = img.numpy() plt.imshow(npimg[0], cmap='gray') plt.show() print(f"Label: {label}") imshow(image)

四、 DataLoader 类详解

torch.utils.data.DataLoader是 PyTorch 中用于加载数据的核心工具。它接收一个 Dataset 对象,并根据配置参数生成一个可迭代对象。

1. 核心功能

DataLoader 的主要职责包括:

  • Batching:将多个样本打包成一个批次。
  • Shuffling:在每个 Epoch 开始时打乱数据顺序,防止模型记忆数据的顺序特征。
  • Multiprocessing:使用多进程并行加载数据,加速数据准备过程(避免 CPU 成为瓶颈)。

2. 创建 DataLoader

from torch.utils.data import DataLoader # 训练集加载器 train_loader = DataLoader( train_dataset, batch_size=64, # 每个批次包含 64 个样本 shuffle=True # 训练时通常需要打乱数据 ) # 测试集加载器 test_loader = DataLoader( test_dataset, batch_size=1000, # 测试时显存压力较小,可以使用更大的 batch_size shuffle=False # 测试时不需要打乱顺序,以便结果对比 )

关于 Batch Size 的选择

通常选择 2 的幂次方(如 32, 64, 128),这有利于 GPU 的并行计算效率。

五、 总结:Dataset 与 DataLoader 的对比

为了清晰地区分这两个概念,我们可以从以下几个维度进行对比:

维度

Dataset

DataLoader

核心职责

定义“数据内容”和“单个样本获取方式”

定义“批量加载策略”和“迭代方式”

核心方法

__getitem__(获取单个),__len__(总数)

内部实现迭代器协议 (__iter__)

预处理位置

__getitem__中定义具体的转换逻辑

不负责预处理,直接使用 Dataset 返回的结果

并行处理

无(仅处理单样本逻辑)

支持多进程加载 (num_workers)

关键参数

root(路径),transform(变换)

batch_size,shuffle,num_workers

一句话总结

Dataset负责把数据从磁盘读出来并处理成模型能看懂的格式(Tensor),而DataLoader负责把这些 Tensor 批量、高效、随机地喂给模型进行训练。

http://www.jsqmd.com/news/84474/

相关文章:

  • 数据链路层复习总结
  • 高中数学
  • Level 1 → Level 2
  • 如何快速掌握Hyperion安卓调试工具:完整入门指南
  • 【Spring框架】SpringMVC基本原理与配置
  • 地理信息与地图行业的新机会:从地图到空间智能
  • openEuler入门学习教程,从入门到精通,openEuler 24.03 中的 Vim 编辑器 —— 全面知识点详解(7) - 指南
  • Emotn TV桌面修改版:三版本满足不同需求,优化时间天气显示与系统性能
  • 中国独立开发者创业实战指南:从技术到商业的变现路径
  • eHR品牌TOP5年度榜单公布!HR系统/HR管理系统市场主流公司推荐 - 全局中转站
  • 32、Django Web 应用开发实战指南
  • 24、Python在多操作系统及云计算环境中的应用
  • JavaScript 在 WebAssembly 时代的角色转变:作为 Wasm 模块编排层与高性能计算逻辑的共存模式研究
  • JavaScript 语言特性的未来演进:探讨可插拔语法扩展(Macros)对前端工具链(Babel/SWC)的底层重构潜力
  • 2022年TRC SCI1区TOP,基于随机分形搜索算法的多无人机四维航迹优化自适应冲突消解方法,深度解析+性能实测
  • 《智能世界2035》——华为预测十年以后智能世界的模样
  • 【Ubuntu】『You are in emergency mode, After logging in, type “journalctl -xb“ to view system logs,...』
  • 【编程和大模型交互】
  • 卷积神经网络中的自适应池化
  • RS-fMRI统计分析及作图入门
  • 如何快速掌握Flutter广告集成:GroMore实战全解
  • 全排列问题(包含重复数字与不可包含重复数字)
  • 纯电动汽车Matlab Simulink仿真模型构建与实现:全面集成电机模型、电池模型、变速器...
  • 基于TTC触发的车辆换道轨迹规划与控制:五次多项式实时规划及Matlab与CarSim联合仿真实验
  • 深入理解 Google Wire:Go 语言的编译时依赖注入框架
  • C++学习之旅【C++类和对象(下)】
  • 格子波尔兹曼LBM在甲烷吸附解吸研究中的应用及文献复现
  • 从零构建大模型智能体:OpenAI Function Calling智能体实战
  • 基于定子磁场矢量控制的异步电机磁链观测模型研究与应用
  • 光伏充电站的“弹性“密码:当电动车遇上数学建模