news 2026/6/10 12:47:50

Day 41 Dataset 与 DataLoader

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 41 Dataset 与 DataLoader

文章目录

  • Day 41 · Dataset 与 DataLoader
      • torchvision 模块速览
      • Step 1 · 定义 `transforms` 管道
    • 一、Dataset:定义“单份数据”
      • 1. 图片观察
      • 2. 两个必须的魔术方法
        • `__getitem__`:让对象支持索引
        • `__len__`:让对象支持 `len()`
      • 3. 自定义 `Dataset` 的伪代码
    • 二、DataLoader:批量调度器
    • 三、总结

Day 41 · Dataset 与 DataLoader

在训练大规模数据集时,显存通常无法一次性装下所有样本,因此必须按批次把数据送入模型。PyTorch 为此提供了两个密不可分的组件:

  1. Dataset:描述每一条数据长什么样、如何读取、是否需要预处理。
  2. DataLoader:负责把一个Dataset切成批次、决定是否乱序、是否并行加载。

下面以经典的MNIST 手写数字数据集为例(训练集 60k、测试集 10k、每张 28×28 灰度图),逐步梳理两者分工。

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoader,Datasetfromtorchvisionimportdatasets,transformsimportmatplotlib.pyplotasplt# 为可复现性固定随机种子torch.manual_seed(42)
<torch._C.Generator at 0x77eb27fe76f0>

torchvision 模块速览

torchvision ├── datasets # 视觉数据集(如 MNIST、CIFAR) ├── transforms # 视觉数据预处理(裁剪、翻转、归一化等) ├── models # 各类预训练模型 ├── utils # 目标检测等常用工具函数 └── io # 图像 / 视频 IO

Step 1 · 定义transforms管道

transforms.Compose可以像数据管道一样串联多步操作,这里先把 PIL 图转成张量,再用 MNIST 的均值、方差做标准化。

# 1. 数据预处理,该写法非常类似于管道pipeline# transforms 模块提供了一系列常用的图像预处理操作# 先归一化,再标准化transform=transforms.Compose([transforms.ToTensor(),# 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,),(0.3081,))# 标准化。MNIST数据集的均值和标准差,这个值很出名,所以直接使用])
# Step 2 · 加载数据集。如果本地没有,会自动下载到 ./datatrain_dataset=datasets.MNIST(root='./data',train=True,download=True,transform=transform)test_dataset=datasets.MNIST(root='./data',train=False,transform=transform)
100%|██████████| 9.91M/9.91M [00:07<00:00, 1.37MB/s] 100%|██████████| 28.9k/28.9k [00:00<00:00, 152kB/s] 100%|██████████| 1.65M/1.65M [00:00<00:00, 1.70MB/s] 100%|██████████| 4.54k/4.54k [00:00<00:00, 15.8MB/s]

PyTorch 的思路是:在“读取数据”这一环就完成预处理,因此transform直接写进datasets.MNIST的构造函数里。

一、Dataset:定义“单份数据”

  • 负责描述数据的来源、存储方式以及取出单个样本所需的所有步骤。
  • 必须能够在索引访问时返回(features, target),并能报告自身的长度。

1. 图片观察

Dataset实例支持下标操作,因此可以像访问列表一样通过索引获取单张图像及其标签。

# 随机选择一张图片,可以重复运行,每次都会随机选择sample_idx=torch.randint(0,len(train_dataset),size=(1,)).item()# 随机选择一张图片的索引# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字image,label=train_dataset[sample_idx]# 获取图片和标签

2. 两个必须的魔术方法

torch.utils.data.Dataset是一个抽象基类,自定义数据集需要重写:

  • __len__():返回样本总数,供len(dataset)或 DataLoader 计算迭代次数。
  • __getitem__(idx):根据索引返回单个样本(通常是(data, label))。

因此train_dataset[sample_idx]才会得到(image, label)

__getitem__:让对象支持索引
classMyList:# 仅为解释魔术方法的示例def__init__(self):self.data=[10,20,30,40,50]def__getitem__(self,idx):returnself.data[idx]my_list_obj=MyList()print(my_list_obj[2])# 输出 30
30
__len__:让对象支持len()
classMyList:def__init__(self):self.data=[10,20,30,40,50]def__len__(self):returnlen(self.data)my_list_obj=MyList()print(len(my_list_obj))# 输出 5
5

3. 自定义Dataset的伪代码

常见写法是在构造函数中读入路径或内存数据,并保存transform__getitem__返回预处理后的样本。

classMNIST(Dataset):def__init__(self,root,train=True,transform=None):self.data,self.targets=fetch_mnist_data(root,train)# 假设这里完成原始数据读取self.transform=transformdef__len__(self):returnlen(self.data)def__getitem__(self,idx):img,target=self.data[idx],self.targets[idx]ifself.transformisnotNone:img=self.transform(img)returnimg,target
组件职责关键方法
Dataset1. 存储数据和标签的映射关系
2. 定义单样本的获取方式
3. 应用样本级预处理(如缩放、裁剪)
__getitem__(idx)
__len__()
DataLoader1. 批量组织样本
2. 并行加载数据
3. 打乱数据顺序
4. 处理多进程问题
迭代器接口(iter()next()
  • 可以把Dataset 想成“厨师”:负责挑选食材、清洗、调味(预处理)。
  • DataLoader 则像“服务员”:按订单把菜(批次)端给模型。
defimshow(img):img=img*0.3081+0.1307# 反标准化回原始像素npimg=img.numpy()plt.imshow(npimg[0],cmap='gray')plt.axis('off')plt.show()print(f"Label:{label}")imshow(image)
Label: 6

二、DataLoader:批量调度器

DataLoader 根据我们提供的Dataset产出一个可迭代对象,它负责:

  • batch_size聚合样本;
  • 根据shuffle决定是否随机打乱顺序;
  • 通过num_workers控制并行加载进程数。
train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)test_loader=DataLoader(test_dataset,batch_size=1000# 测试集通常不需要 shuffle)

三、总结

维度DatasetDataLoader
核心职责定义“数据是什么”以及如何得到单个样本决定怎样批量、按顺序或乱序地取数据
核心接口__getitem____len__通过参数控制加载逻辑,无需继承
预处理__getitem__transform中完成不做预处理,直接消费 Dataset 的输出
并行能力单线程读取num_workers>0时可多进程读取
典型参数roottransformtarget_transformbatch_sizeshufflenum_workers

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/10 14:59:24

京东自动化脚本终极指南:5分钟搭建你的智能签到系统

京东自动化脚本终极指南&#xff1a;5分钟搭建你的智能签到系统 【免费下载链接】jd_scripts-lxk0301 长期活动&#xff0c;自用为主 | 低调使用&#xff0c;请勿到处宣传 | 备份lxk0301的源码仓库 项目地址: https://gitcode.com/gh_mirrors/jd/jd_scripts-lxk0301 还在…

作者头像 李华
网站建设 2026/6/10 15:06:29

WebForms 事件

WebForms 事件 引言 WebForms 是微软在 .NET 框架中提供的一种用于构建动态网页的技术。在 WebForms 开发中,事件处理是至关重要的。本文将深入探讨 WebForms 事件的概念、类型、生命周期以及如何进行事件处理,旨在帮助开发者更好地理解和应用这一技术。 什么是 WebForms …

作者头像 李华
网站建设 2026/6/10 7:31:21

SQL FOREIGN KEY

SQL FOREIGN KEY 在数据库设计中,FOREIGN KEY 是一种非常重要的约束,它用于保证数据库表之间的引用完整性。本文将详细介绍 SQL 中的 FOREIGN KEY 约束,包括其定义、作用、语法以及在实际应用中的注意事项。 一、什么是 FOREIGN KEY? FOREIGN KEY 是一种关系型数据库约束…

作者头像 李华
网站建设 2026/6/10 16:25:16

3步搞定BetterNCM插件:让你的网易云音乐脱胎换骨

3步搞定BetterNCM插件&#xff1a;让你的网易云音乐脱胎换骨 【免费下载链接】BetterNCM-Installer 一键安装 Better 系软件 项目地址: https://gitcode.com/gh_mirrors/be/BetterNCM-Installer 还在忍受网易云音乐单调的界面和有限的功能吗&#xff1f;BetterNCM插件正…

作者头像 李华
网站建设 2026/6/9 20:15:06

研发OKR的制定方法

制定研发&#xff08;R&D&#xff09;团队的OKR&#xff08;Objectives and Key Results&#xff09;&#xff0c;是企业管理实践中的一项“高难度”挑战。其核心难点在于如何平衡“研发的探索性”与“业务的确定性”。研发OKR的制定&#xff0c;其核心方法论是实现两大转变…

作者头像 李华
网站建设 2026/6/9 21:46:44

MVC 控制器:架构的核心与实现

MVC 控制器:架构的核心与实现 引言 在软件开发领域,MVC(Model-View-Controller)架构模式是一种广泛采用的设计模式,它将应用程序分为三个核心组件:模型(Model)、视图(View)和控制器(Controller)。控制器作为MVC架构中的核心,负责处理用户输入、更新模型和选择视…

作者头像 李华