当前位置: 首页 > news >正文

深度学习框架PyTorch笔记(三)数据集类(Data Set)与数据加载器(Data Loader)

深度学习框架PyTorch笔记(三)数据集类(Data Set)与数据加载器(Data Loader)

​ 在PyTorch中,数据集(Data Set)和数据加载器(Data Loader)是实现深度学习模型和测试的基本组件。下面将首先介绍数据集(Data Set)和数据加载器(Data Loader)的概念,然后介绍如何创建和使用PyTorch中的数据加载器的一些步骤和示例。

数据集类(Data Set)是指存储和表示数据的类或接口。它通常用于封装数据,以便能够在机器学习任务中使用。数据集可以是任何形式的数据,比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法,以便可以轻松地将其用于模型训练、验证和测试。

数据加载器(Data Loader)是一个提供批量加载数据的工具。它通过将数据集分割成小批量,并按照一定的顺序加载到内存中,以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。

​ PyTorch中的 torch.utils.data.Datasettorch.utils.data.DataLoader 是数据加载和处理的核心组件。它们将数据读取与模型训练解耦,提供高效、灵活的数据迭代方式。下面从基础概念、自定义加载器参数、多进程机制等方面进行详细介绍。

1.数据集(Data Set)

1.1 自定义数据集定义实现

Data Set 是一个抽象类,表示一个数据集。任何自定义数据集都必须继承它,自定义DataSet类,必须实现它构造函数和两个方法:

  • __init__: 在 实例化DataSet 对象运行一次。我们初始化包含图像的目录、注释文件和transform与 target_transform.

  • __len__:返回数据集的总样本数。len(dataset)会调用它。

  • __getitem__(self, idx):根据整数索引idx会返回一个样本(通常为特征和标签)。dataset[idx] 会调用它。

其作用就是实现通过索引访问对应的数据以及标签

from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]

使用自定义数据集时,可以用将其与torch.utils.data.DataLoader 结合使用,以便进行数据的批量加载和处理和训练。

1.2 两种自定义数据集风格

​ 在PyTorch中,自定义数据集有两个核心设计模式:映射式(Map-Style)可迭代式(Iterable-style) 。它们的差异不仅是实现接口不同,更反映了“随机访问”与“流式读取”两种数据消费范式的根本区别。下面从设计理念、实现细节、多进程交互、适用场景等方面深入解析。

  • Map-style datasets(映射式):就是上述需要实现 __getitem____len__ 的数据集,它通过索引映射到数据样本。适用于所有数据能一次性放入索引结构(如列表、文件路径列表)的场景。
  • Iterable-style datasets(可迭代式):当数据集太大无法一次性加载,或数据是流式读取时(如实时日志、数据库流),可以继承 IterableDataset,实现 __iter__ 方法返回一个迭代器。这种数据集不能使用 len(),也无法使用随机采样(shuffle)的 loader,需使用 Sampler 的特定变体。

在后续笔记我们将详细介绍。

1.3 内置数据集

​ PyTorch提供了一些常用数据集类,主要在torchvision.datasetstorchtext.datasetstorchaudio.datasets中。例如:

  • torchvision.datasets.MNISTCIFAR10ImageFolder(从文件夹结构加载图片,子文件夹为类别)
  • torchtext.datasets.IMDB
  • torchaudio.datasets.LIBRISPEECH

这些内置类都继承自 Dataset,使用时可自动下载数据,并提供标准化访问方式。

​ 现在我们来展示一个如何从TorchVision加载了Fahion-MINIST由60000个训练样本和10000个测试样本组成。每个样本包含一个\(28\times{28}\) 灰度图像和一个来自10个类别之一的关联标签。下面使用以下参数加载FashionMINIST数据集:

  • root:是存储路径、测试数据的路径。
  • train:指定训练集或测试数据集。
  • download=True:如果root路径下没有数据,则从网上下载数据。
  • transformtarget_transform是指定特征和标签转换。
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="./data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="./data",train=False,download=True,transform=ToTensor()
)

我们可以用索引来访问数据集中的样本,用 matplotlib 可视化图形样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx]figure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

其运行结果如下:

Figure_1

2. 数据加载器(Data Loader)

数据加载器(Data Loader)DataSet 封装为可迭代对象,负责批量加载、打乱数据、多进程并行加载等功能。其功能如下:

  • 批量加载数据:DataLoader可以从数据集中按照指定的批量大小加载数据。每个批次的数据可以作为一个张量或列表返回,便于进行后续的处理和训练。
  • 数据随机洗牌:通过设置shuffle=True,DataLoader可以在每个迭代周期中对数据进行随机洗牌,以减少模型对数据顺序的依赖性,提高训练效果。
  • 多线程数据加载:DataLoader支持使用多个线程来并行加载数据,加快数据加载的速度,提高训练效率。
  • 数据批次采样:除了按照批量大小加载数据外,DataLoader还支持自定义的数据批次采样方式。可以通过设置batch_sampler参数来指定自定义的批次采样器,例如按照指定的样本顺序或权重进行采样。

数据加载器的API形式核心参数

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None, multiprocessing_context=None,generator=None, prefetch_factor=2, persistent_workers=False)
  • dataset:要加载的 Dataset 对象(映射式或可迭代式)。
  • batch_size:每个批次的样本数,默认为 1。
  • shuffle:是否在每个 epoch 开始时打乱数据顺序(仅对映射式有效)。打乱基于 RandomSampler
  • sampler:自定义采样器,继承自 torch.utils.data.Sampler。定义数据索引的抽取策略。如果指定,shuffle 必须为 False
  • batch_sampler:类似 sampler,但每次返回一批索引,与 batch_sizeshufflesampler 互斥。
  • num_workers:用于数据加载的子进程数。0 表示在主进程中加载,通常设置大于 0 可以加速数据预处理,利用多核。
  • collate_fn:函数,定义如何将多个样本列表合并为一个批次。默认 collate_fn 会将所有样本沿第0维堆叠成张量,通常对于同型数据有效。如果样本结构不一致(如不同长度序列),需要自定义。
  • pin_memory:若为 True,数据加载器在返回张量前将其复制到 CUDA 固定内存,加速数据传输到 GPU。仅适用于 CUDA。
  • drop_last:若为 True,丢弃最后一个不完整批次(当总样本数不能被 batch_size 整除时)。在训练时如果要求严格固定批次大小(如 BatchNorm)应设为 True
  • timeout:从 worker 进程获取一个 batch 的超时时间(秒)。如果超时会抛异常。
  • worker_init_fn:每个 worker 进程的初始化函数,参数为 worker id,可用于设置随机种子等。
  • generator:用于生成随机采样的伪随机数生成器,保证可复现性。
  • prefetch_factor:每个 worker 预先加载的 batch 数(默认 2),增加可以让 GPU 更少等待。
  • persistent_workers:若为 True,在数据集被消费一次后不会关闭 worker 进程,可保持 worker 存活以加速后续 epoch。

数据调用案例Demo

import torch
from torch.utils.data import Dataset, DataLoader# 自定义数据集类
class MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]# 自定义数据加载器类
class MyDataLoader(DataLoader):def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0):super().__init__(dataset, batch_size, shuffle, num_workers=num_workers)def collate_fn(self, batch):# 自定义的数据预处理、合并等操作# 这里只是简单地将样本转换为Tensor,并进行堆叠return torch.stack(batch)# 自定义数据集类
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)# 创建数据加载器实例
dataloader = MyDataLoader(dataset, batch_size=2, shuffle=True)# 遍历数据加载器
for batch in dataloader:# batch是一个包含多个样本的张量(或列表)# 这里可以对批次数据进行处理print(batch)

3.实战案例

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):x = torch.tensor(load_iris().data)y = torch.tensor(load_iris().target)# 数据归一化x_min = torch.min(x, dim=0).valuesx_max = torch.max(x, dim=0).valuesx = (x - x_min) / (x_max - x_min)if shuffle:idx = torch.randperm(x.shape[0])x = x[idx]y = y[idx]return x, y# 自定义鸢尾花数据类
class IrisDataset(Dataset):def __init__(self, mode='train', num_train=120, num_dev=15):super(IrisDataset, self).__init__()x, y = load_data(shuffle=True)if mode == 'train':self.x, self.y = x[:num_train], y[:num_train]elif mode == 'dev':self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]else:self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]def __getitem__(self, idx):return self.x[idx], self.y[idx]def __len__(self):return len(self.x)batch_size = 16# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

4.总 结

  • ataset 定义数据源及其访问方式,映射式最常用,流式数据用 IterableDataset
  • DataLoader 封装采样、批处理、多进程加载和内存固定等功能,参数丰富。
  • 通过自定义 samplercollate_fn 可以灵活处理各种数据形式和不平衡问题。
  • 多进程加载是加速训练的关键,需注意内存复制和系统兼容性。

掌握 DatasetDataLoader 的用法与内部机制,能够让你根据实际需求搭建高效的数据管道,将 I/O 瓶颈降到最低,从而充分释放 GPU 计算能力。

5.参考资料

https://cloud.tencent.com/developer/article/2055224?policyId=1003

https://cloud.tencent.com/developer/article/2440506

https://cloud.tencent.com/developer/article/1010379?policyId=1003

http://www.jsqmd.com/news/971160/

相关文章:

  • JAVA:继承
  • m4s-converter:三步解决B站缓存视频无法播放的终极方案
  • 西安 GEO 优化服务商深度解析:服务商选择核心原因分析
  • 如何用开源工具实现高效图片管理:5步打造个人视觉搜索引擎
  • Play Integrity Checker实战指南:轻松构建Android设备安全验证
  • ibbot角色智能体 v2.0 升级公告:全新上下文限制功能上线————灵活适配速度与深度,让每位数字伙伴更懂你
  • 抖音视频下载架构解析:异步批量处理与无水印技术实现
  • 2026 年 GEO 公司推荐指南:技术与合规双轮驱动下的 Top5 企业解析 - GEO优化
  • 系统架构设计师-从 PDR到 WPDRRC 的模型演进与架构实践
  • 记录跨境独立站 海外VPS组合落地的一线实操动态与调研手记
  • 2026 郑州防水补漏服务商口碑测评榜单|全屋渗漏维修机构优选指南(6 月最新) - 宅安选房屋修缮
  • J4125 安装 OPNsense
  • 算法不稳定,则就希望环境稳定
  • 12700黄大年茶思屋榜文第127期 | 鸿蒙领域前沿技术难题抽取篇
  • 第3课:开发环境全套搭建【Python环境、LangChain、LangSmith依赖安装与全局配置】
  • 开源自动化工具新范式:如何用LCU API构建你的英雄联盟技术助手
  • 小语言模型(SLM)技术深度解析:从剪枝蒸馏到端侧推理的轻量化 AI 全栈技术
  • 如何在本地电脑上实现千万级图片秒级搜索:完整免费指南
  • 佛山搬家公司哪家强?大件搬迁运输实力见证 - 从来都是英雄出少年
  • 梳理中小出海独立站落地阶段关于WordPress 海外主机的实操参考路径
  • 小红书全自动发表评论基本完成
  • 暗黑破坏神2存档编辑器d2s-editor:从零开始掌握游戏数据可视化修改
  • 2026年高口碑GEO优化服务商精选:五家企业的核心技术能力经受考验 - GEO优化
  • Oops Framework-7-由空项目创建Oops Framework项目
  • 解锁第三方鼠标的全部潜能:Mac Mouse Fix 让你的普通鼠标秒变生产力神器
  • 3分钟解锁B站缓存视频的终极免费解决方案:m4s-converter完整指南
  • 流量不够用怎么办?作为女生我真的很烦这件事!终于找到低月租大流量卡了,19元起,运营商直发 - 172号卡
  • Discord消息批量清理技术深度解析:Undiscord实现机制详解
  • 5分钟学会使用免费在线法线贴图生成器,让3D模型细节飙升300%!
  • 跨视域融合感知技术,搭建口岸通关智能顶级视频孪生系统