news 2026/4/23 12:49:31

Day41 Dataset和Dataloader

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day41 Dataset和Dataloader
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具 from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块 import matplotlib.pyplot as plt # 设置随机种子,确保结果可复现 torch.manual_seed(42) # 1. 数据预处理,该写法非常类似于管道pipeline # transforms 模块提供了一系列常用的图像预处理操作 # 先归一化,再标准化 transform = transforms.Compose([ transforms.ToTensor(), # 转换为张量并归一化到[0,1] transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用 ]) # 2. 加载MNIST数据集,如果没有会自动下载 train_dataset = datasets.MNIST( root='./data', train=True, download=True, transform=transform ) test_dataset = datasets.MNIST( root='./data', train=False, transform=transform ) import matplotlib.pyplot as plt # 随机选择一张图片,可以重复运行,每次都会随机选择 sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引 # len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字 image, label = train_dataset[sample_idx] # 获取图片和标签 # 示例代码 class MyList: def __init__(self): self.data = [10, 20, 30, 40, 50] def __getitem__(self, idx): return self.data[idx] # 创建类的实例 my_list_obj = MyList() # 此时可以使用索引访问元素,这会自动调用__getitem__方法 print(my_list_obj[2]) # 输出:30 class MyList: def __init__(self): self.data = [10, 20, 30, 40, 50] def __len__(self): return len(self.data) # 创建类的实例 my_list_obj = MyList() # 使用len()函数获取元素数量,这会自动调用__len__方法 print(len(my_list_obj)) # 输出:5 # minist数据集的简化版本 class MNIST(Dataset): def __init__(self, root, train=True, transform=None): # 初始化:加载图片路径和标签 self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签 self.transform = transform # 预处理操作 def __len__(self): return len(self.data) # 返回样本总数 def __getitem__(self, idx): # 获取指定索引的样本 # 获取指定索引的图像和标签 img, target = self.data[idx], self.targets[idx] # 应用图像预处理(如ToTensor、Normalize) if self.transform is not None: # 如果有预处理操作 img = self.transform(img) # 转换图像格式 # 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化 return img, target # 返回处理后的图像和标签 # 可视化原始图像(需要反归一化) def imshow(img): img = img * 0.3081 + 0.1307 # 反标准化 npimg = img.numpy() plt.imshow(npimg[0], cmap='gray') # 显示灰度图像 plt.show() print(f"Label: {label}") imshow(image) # 3. 创建数据加载器 train_loader = DataLoader( train_dataset, batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关 shuffle=True # 随机打乱数据 ) test_loader = DataLoader( test_dataset, batch_size=1000 # 每个批次1000张图片 # shuffle=False # 测试时不需要打乱数据 )

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/28 1:01:55

开源中国报道申请:获得官方渠道背书

NVIDIA TensorRT:解锁深度学习推理性能的关键引擎 在人工智能应用加速落地的今天,一个训练得再完美的模型,如果无法在生产环境中快速、稳定地响应请求,其价值就会大打折扣。尤其是在视频监控、语音交互、推荐系统等对延迟敏感的场…

作者头像 李华
网站建设 2026/4/18 11:19:54

性能回归测试:持续验证TensorRT优化稳定性

性能回归测试:持续验证TensorRT优化稳定性 在自动驾驶的感知系统中,一个目标检测模型从实验室准确率提升1%到实际路测时推理延迟增加30毫秒——这足以让车辆错过关键避障时机。这种“精度换性能”的隐性代价,正是AI工程化落地中最危险的暗礁。…

作者头像 李华
网站建设 2026/4/23 11:29:49

测试《A Simple Algorithm for Fitting a Gaussian Function》拟合

https://github.com/JohannesMeyersGit/1D-Gaussian-Fitting/blob/main/Itterativ_1D_Gaussian_Fit.py 源码 每次迭代采样不同子区间,error(拟合的均值-实际均值) 先减低后增,改成样本点不变 error 曲线看上去正常,但是 A 的值离实际越来越大&#xff…

作者头像 李华
网站建设 2026/4/19 23:40:02

华为全联接大会演讲:跨厂商合作可能性探索

华为全联接大会演讲:跨厂商合作可能性探索 在AI模型日益复杂、部署场景愈发多样的今天,一个现实问题正摆在所有硬件与系统厂商面前:如何让训练好的深度学习模型,在不同品牌、不同架构的设备上都能高效运行?尤其是在华为…

作者头像 李华
网站建设 2026/4/22 11:31:44

GitHub项目托管:公开示例代码促进传播

GitHub项目托管:公开示例代码促进传播 在今天的AI工程实践中,一个训练得再完美的深度学习模型,如果无法高效部署到生产环境,其价值就会大打折扣。尤其是在视频分析、自动驾驶、实时推荐等对延迟敏感的场景中,推理速度往…

作者头像 李华
网站建设 2026/4/17 13:52:57

轻量化ssh工具Dropbear 介绍与使用说明

一、Dropbear 是什么? Dropbear 是一个开源、轻量级的 SSH 服务器和客户端实现,主要特点是: 体积小:比 OpenSSH 小很多,非常适合嵌入式设备、路由器、单板机(如 OpenWrt、树莓派精简系统)等。功…

作者头像 李华