当前位置: 首页 > news >正文

Day41 Dataset和Dataloader

import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具 from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块 import matplotlib.pyplot as plt # 设置随机种子,确保结果可复现 torch.manual_seed(42) # 1. 数据预处理,该写法非常类似于管道pipeline # transforms 模块提供了一系列常用的图像预处理操作 # 先归一化,再标准化 transform = transforms.Compose([ transforms.ToTensor(), # 转换为张量并归一化到[0,1] transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用 ]) # 2. 加载MNIST数据集,如果没有会自动下载 train_dataset = datasets.MNIST( root='./data', train=True, download=True, transform=transform ) test_dataset = datasets.MNIST( root='./data', train=False, transform=transform ) import matplotlib.pyplot as plt # 随机选择一张图片,可以重复运行,每次都会随机选择 sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引 # len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字 image, label = train_dataset[sample_idx] # 获取图片和标签 # 示例代码 class MyList: def __init__(self): self.data = [10, 20, 30, 40, 50] def __getitem__(self, idx): return self.data[idx] # 创建类的实例 my_list_obj = MyList() # 此时可以使用索引访问元素,这会自动调用__getitem__方法 print(my_list_obj[2]) # 输出:30 class MyList: def __init__(self): self.data = [10, 20, 30, 40, 50] def __len__(self): return len(self.data) # 创建类的实例 my_list_obj = MyList() # 使用len()函数获取元素数量,这会自动调用__len__方法 print(len(my_list_obj)) # 输出:5 # minist数据集的简化版本 class MNIST(Dataset): def __init__(self, root, train=True, transform=None): # 初始化:加载图片路径和标签 self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签 self.transform = transform # 预处理操作 def __len__(self): return len(self.data) # 返回样本总数 def __getitem__(self, idx): # 获取指定索引的样本 # 获取指定索引的图像和标签 img, target = self.data[idx], self.targets[idx] # 应用图像预处理(如ToTensor、Normalize) if self.transform is not None: # 如果有预处理操作 img = self.transform(img) # 转换图像格式 # 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化 return img, target # 返回处理后的图像和标签 # 可视化原始图像(需要反归一化) def imshow(img): img = img * 0.3081 + 0.1307 # 反标准化 npimg = img.numpy() plt.imshow(npimg[0], cmap='gray') # 显示灰度图像 plt.show() print(f"Label: {label}") imshow(image) # 3. 创建数据加载器 train_loader = DataLoader( train_dataset, batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关 shuffle=True # 随机打乱数据 ) test_loader = DataLoader( test_dataset, batch_size=1000 # 每个批次1000张图片 # shuffle=False # 测试时不需要打乱数据 )

@浙大疏锦行

http://www.jsqmd.com/news/150320/

相关文章:

  • 过量化导致精度下降?TensorRT补偿机制揭秘
  • 计算机Java毕设实战-基于JAVA的医院预约挂号管理系统的设计与实现基于Web的医院门诊在线预约挂号系统设计与实现【完整源码+LW+部署说明+演示视频,全bao一条龙等】
  • 智慧交通信号灯调控:城市大脑背后的推理引擎
  • springboot_ssm“云课堂”在线教育系统的设计与开发
  • 2025最新!9个AI论文工具测评:继续教育者必看的科研写作指南
  • 前端新人必看:IIFE到底解决了什么问题?(附实战技巧)
  • springboot_ssm“在云端”--在线音乐分享平台的设计与实现
  • 【毕业设计】基于JAVA的医院预约挂号管理系统的设计与实现(源码+文档+远程调试,全bao定制等)
  • 模型压缩终极形态:TensorRT + 知识蒸馏联合优化
  • 稀疏+量化双管齐下:极限压缩大模型体积
  • 2025最新!专科生必看9款AI论文工具测评与推荐
  • 横向对比测试:TensorRT vs OpenVINO vs TFLite
  • GitHub项目托管:公开示例代码促进传播
  • 黑客松比赛赞助:激发基于TensorRT的创新应用
  • 4次拷贝变0次:我用现代C++撸了个生产级零拷贝缓存
  • 2025年共创广告工厂标识系统深度解析:6S车间可视化、户外市政标识一体化解决方案权威推荐 - 品牌企业推荐师(官方)
  • 学校启用AIGC检测后,这十大降AI工具最稳
  • 2025年退火处理厂家权威推荐:南通汉科新能源领衔,五大退火工艺(完全/球化/去应力等)核心技术实力深度解析 - 品牌企业推荐师(官方)
  • SpringBoot-day01 学习心得
  • 十佳降AI工具实测,知网AIGC检测也能过
  • 冷启动问题解决:预加载TensorRT引擎提升首响速度
  • SpringBoot-day01-学习心得
  • 稀疏化支持进展:TensorRT如何利用结构化剪枝
  • Java计算机毕设之基于Springboot+Vue的电子商务订单管理系统设计与实现(完整前后端代码+说明文档+LW,调试定制等)
  • 论文降AI率工具排行榜:2025十佳推荐
  • 【毕业设计】基于springboot的校园二手交易平台(源码+文档+远程调试,全bao定制等)
  • Flask2入门开发详解
  • springboot_ssm“小饰界”线上饰品商城的设计与实现
  • 【效率工具】告别重复劳动!我开发了一个批量新建文件/文件夹工具
  • 提示词工程:与大模型高效对话的必备技能,程序员必学!