news 2026/4/23 19:15:16

DAY 40 Dataset类和Dataloader类

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DAY 40 Dataset类和Dataloader类

一、Dataset类的_getitem_和_len_方法

在 PyTorch 中,torch.utils.data.Dataset 是所有自定义数据集的抽象基类,它规定了数据集必须实现两个核心方法:__len__ 和 __getitem__。这两个方法是 DataLoader 加载数据的基础,决定了数据集的 “大小” 和 “如何按索引取样本”。

Dataset 类的核心作用:Dataset 类的设计目标是封装数据集的逻辑(如数据读取、预处理、标签映射等),对外暴露统一的接口,让 DataLoader 可以无感地加载、批量处理、打乱数据。
自定义数据集时,必须继承 Dataset 并实现 __len__ 和 __getitem__(否则实例化会抛出 NotImplementedError)。

__getitem__ 方法
1. 核心作用
根据传入的索引 index,返回该索引对应的单个样本(通常是 “特征 + 标签” 的组合)。
DataLoader 会循环调用该方法(按索引取样本),并将多个样本拼接成批次,是数据加载的核心逻辑。
2. 实现规则

  • 入参:仅接收一个整数 index(范围:0 ≤ index < __len__());
  • 返回值:格式灵活,常见形式:

元组:(feature, label)(最常用);

字典:{"feature": feature, "label": label}(多模态 / 多特征场景更易读);

单个值:仅特征(无监督学习场景)。

3. 关键注意点

  • 索引合法性:DataLoader 通常会保证 index 在 [0, __len__()-1] 范围内,但自定义时建议避免越界;
  • 预处理逻辑:数据预处理(如归一化、图像裁剪、文本分词)建议放在该方法中(DataLoader 支持多进程加载,预处理并行执行效率更高);
  • 数据类型:返回的特征建议转为 torch.Tensor(方便后续模型计算),标签可根据需求保留 int/float 或转为 Tensor。

__len__ 方法
1. 核心作用
返回数据集的总样本数量,DataLoader 依赖该方法知道数据集的 “边界”,例如:

  • 计算迭代轮次(总样本数 / 批次大小);
  • 随机打乱时确定索引范围。

2. 实现规则

  • 无入参,仅返回一个非负整数;
  • 必须与数据集的实际样本数一致(否则会导致索引越界或数据加载不全)。

二、Dataloader类

DataLoader 核心作用:

  1. 自动按 batch_size 从 Dataset 中取多个样本,拼接成批次数据(如把多个 (feature, label) 拼接成 (batch_feature, batch_label));
  2. 支持数据打乱(shuffle),避免模型过拟合;
  3. 支持多进程加载(num_workers),提升数据读取效率(尤其适合大数据集 / 硬盘读取场景);
  4. 灵活的批次拼接逻辑(collate_fn),适配不同类型数据(如变长文本、多模态数据);
  5. 支持内存锁页(pin_memory),加速数据从 CPU 到 GPU 的传输。
参数名作用与说明默认值
dataset必须传入的 Dataset 实例(自定义 / 内置均可),DataLoader 基于它取样本
batch_size每个批次的样本数量1
shuffle是否在每个 epoch 开始时打乱数据索引(训练集建议 True,测试集建议 False)False
num_workers用于数据加载的子进程数(多进程加速);0 表示主进程加载0
  1. drop_last
若数据集总数不能被 batch_size 整除,是否丢弃最后一个不完整批次False
collate_fn自定义批次拼接函数,用于处理样本的拼接逻辑(如变长文本、自定义数据结构)None
pin_memory是否将加载的数据存入 CUDA 锁页内存(GPU 训练时设为 True,加速传输)False
timeout数据加载的超时时间(秒),防止子进程挂起0
sampler自定义索引采样策略(优先级高于 shuffle)None
batch_sampler自定义批次索引采样策略(与 batch_size/shuffle/sampler 互斥)None

DataLoader 工作原理:

  1. 索引生成:根据 Dataset.__len__() 获取总索引范围,结合 shuffle/sampler 生成索引序列;
  2. 批次切分:将索引序列按 batch_size 切分成多个批次索引(如 [0,1], [2,3], [4]);
  3. 样本读取:对每个批次的索引,调用 Dataset.__getitem__(index) 获取单个样本;
  4. 批次拼接:通过 collate_fn 将多个单个样本拼接成批次数据(默认拼接成 Tensor 矩阵);
  5. 多进程加速:num_workers > 0 时,子进程并行执行 “样本读取 + 预处理”,主进程仅负责拼接和分发。

核心结论

Dataset类:定义数据的内容和格式(即“如何获取单个样本”),包括:

- 数据存储路径/来源(如文件路径、数据库查询)。

- 原始数据的读取方式(如图像解码为PIL对象、文本读取为字符串)。

- 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过`transform`参数实现)。

- 返回值格式(如`(image_tensor, label)`)。

DataLoader类:定义数据的加载方式和批量处理逻辑(即“如何高效批量获取数据”),包括:

- 批量大小(batch_size)。

- 是否打乱数据顺序(shuffle)。

三、MNIST手写数字数据集

import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具 from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块 import matplotlib.pyplot as plt # 设置随机种子,确保结果可复现 torch.manual_seed(42) # 1. 数据预处理,该写法非常类似于管道pipeline # transforms 模块提供了一系列常用的图像预处理操作 # 先归一化,再标准化 transform = transforms.Compose([ transforms.ToTensor(), # 转换为张量并归一化到[0,1] transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用 ]) # 2. 加载MNIST数据集,如果没有会自动下载 train_dataset = datasets.MNIST( root='./data', train=True, download=True, transform=transform ) test_dataset = datasets.MNIST( root='./data', train=False, transform=transform )

作业

# 1. 导入必要库 import torch from torchvision import datasets, transforms import matplotlib.pyplot as plt import numpy as np # 2. 固定随机种子(可选,保证结果一致) torch.manual_seed(42) # 3. 定义数据预处理(CIFAR-10专用均值/标准差) # 说明:CIFAR-10的全局均值和标准差是行业公认值,标准化用 transform = transforms.Compose([ transforms.ToTensor(), # 转Tensor:把0-255的PIL图片→0-1的Tensor,维度[C, H, W](3,32,32) transforms.Normalize( mean=[0.4914, 0.4822, 0.4465], # R/G/B三通道均值 std=[0.2470, 0.2435, 0.2616] # R/G/B三通道标准差 ) ]) # 4. 加载CIFAR-10数据集(自动下载) # 训练集 train_dataset = datasets.CIFAR10( root='./data', # 数据集保存路径 train=True, # 加载训练集(False则加载测试集) download=True, # 本地没有则自动下载 transform=transform # 应用预处理 ) # 5. 关键:提取单张图片并可视化 # 5.1 取数据集第0个样本(特征Tensor + 标签) img_tensor, label_idx = train_dataset[0] # img_tensor.shape = [3,32,32],label_idx是0-9的整数 print(f"图片Tensor形状:{img_tensor.shape}") # 输出:torch.Size([3, 32, 32]) print(f"图片标签索引:{label_idx}") # 输出:6(对应类别“青蛙”) # 5.2 定义CIFAR-10类别名称(对应索引0-9) cifar10_classes = [ '飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车' ] print(f"图片对应类别:{cifar10_classes[label_idx]}") # 输出:青蛙 # 5.3 预处理还原(因为Normalize后数值不在0-1,需要反归一化才能正常显示) # 反归一化公式:img = (img_tensor * std) + mean mean = np.array([0.4914, 0.4822, 0.4465]) std = np.array([0.2470, 0.2435, 0.2616]) # Tensor→numpy,维度从[C,H,W]→[H,W,C](matplotlib需要这个顺序) img_np = img_tensor.numpy().transpose((1, 2, 0)) img_np = img_np * std + mean # 反归一化 img_np = np.clip(img_np, 0, 1) # 确保数值在0-1之间(避免归一化后溢出) # 5.4 可视化图片 plt.figure(figsize=(4, 4)) # 设置图片大小 plt.imshow(img_np) # 显示图片 plt.title(f"Label: {cifar10_classes[label_idx]} (索引{label_idx})") plt.axis('off') # 隐藏坐标轴 plt.show()

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/23 12:47:12

鸿蒙应用冷启动优化:Flutter首屏秒开与白屏治理实战

前言&#xff1a;用户流失的“第一秒” 在鸿蒙应用开发中&#xff0c;启动速度是用户的第一印象。对于混合了Flutter的鸿蒙应用&#xff0c;常面临一个尴尬的场景&#xff1a;原生页面秒开&#xff0c;而包含Flutter的页面却有明显的延迟&#xff08;白屏或卡顿&#xff09;。…

作者头像 李华
网站建设 2026/4/23 14:27:43

1.15 并行编程

1.并行循环基本语法 2.并行循环原理 3.并行循环中的异常处理 4.停止 5.中断1.并行循环基本语法 C#中的Parallel类(位于 System.Threading.Tasks 命名空间)是.NET提供的并行编程核心工具, 旨在简化"数据并行"和 "任务并行"开发, 充分利用多核CPU资源, 避免手…

作者头像 李华
网站建设 2026/4/23 17:04:57

Unreal Engine文档查询太难?LobeChat快速定位

Unreal Engine文档查询太难&#xff1f;LobeChat快速定位 在开发一款基于 UE5 的开放世界游戏时&#xff0c;团队成员频繁遇到一个看似简单却异常耗时的问题&#xff1a;如何让角色正确跳跃&#xff1f;有人查蓝图节点&#xff0c;有人翻 C API 文档&#xff0c;还有人去论坛翻…

作者头像 李华
网站建设 2026/4/23 13:30:01

MeshLab文件格式完全指南:从入门到精通的实用技巧

MeshLab文件格式完全指南&#xff1a;从入门到精通的实用技巧 【免费下载链接】meshlab The open source mesh processing system 项目地址: https://gitcode.com/gh_mirrors/me/meshlab MeshLab作为开源的网格处理系统&#xff0c;其强大的文件格式支持能力是众多用户选…

作者头像 李华
网站建设 2026/4/23 5:41:15

15min的博客—回归的学习方法

15min的博客—回归的学习方法之前心态原因&#xff0c;对C语言的钻研有了一些中断&#xff0c;但现在&#xff0c;我又回来钻研了&#xff01;我想&#xff1a;怎样让我快速回忆一个星期前积累的知识呢&#xff1f;后来我决定&#xff1a;以“三子棋”一个大板块要求带我共同回…

作者头像 李华
网站建设 2026/4/23 5:40:37

瑞数6补环境案例(3)——吐环境脚本

【Bilibili】&#xff1a;餍足SATISFY 作者声明&#xff1a;文章仅供学习交流与参考&#xff01;严禁用于任何商业与非法用途&#xff01;否则由此产生的一切后果均与作者无关&#xff01;如有侵权&#xff0c;请联系作者本人进行删除&#xff01; 商业合作&#xff1a;yanzuk…

作者头像 李华