您现在的位置是:首页 >技术杂谈 >关于PyTorch中的 torch.utils.data.DataLoader网站首页技术杂谈

关于PyTorch中的 torch.utils.data.DataLoader

许野平 2024-06-18 00:01:02
简介关于PyTorch中的 torch.utils.data.DataLoader

一、DataLoader

torch.utils.data.DataLoader是PyTorch中一个用于数据加载的工具类,主要用于将样本数据划分为多个小批次(batch),以便进行训练、测试、验证等任务。该类支持多线程异步数据读取和数据预处理,使得模型训练更高效、更快速。

使用DataLoader时需要传入一个Dataset对象,Dataset对象提供了访问样本数据的接口。例如,可以使用PyTorch提供的torchvision.datasets中的一些内置数据集(如MNIST、CIFAR-10等)来构建Dataset对象。

构建DataLoader对象需要指定batch_size参数,它决定了每个小批次的样本数;还可以设置shuffle参数来指定是否对样本进行洗牌(打乱顺序)。在构建DataLoader对象之后,可以通过for循环来遍历DataLoader对象,每次返回一个包含batch_size个样本的小批次数据。

示例代码:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# 加载MNIST数据集
mnist_train = MNIST('./data', train=True, download=True, transform=ToTensor())

# 构建DataLoader对象
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)

# 遍历DataLoader对象,训练模型
for x, y in train_loader:
    # 在这里执行模型训练操作
    pass

在上面的示例中,首先加载MNIST数据集,然后构建一个batch_size为64、打乱顺序的DataLoader对象。之后就可以使用for循环遍历DataLoader对象,每次返回64个样本数据进行模型训练。

除了基本参数之外,DataLoader对象还支持其他一些参数设置,例如num_workers(指定数据读取的线程数)、drop_last(指定是否丢弃最后一个不足batch_size的小批次)等。这些参数的具体用法可以参考PyTorch官方文档。

二、MNIST

2.1 数据加载

上述代码中的mnist_train = MNIST('./data', train=True, download=True, transform=ToTensor())是用来加载MNIST数据集的。

MNIST是PyTorch中提供的一个内置数据集类,用于加载手写数字识别的数据集。它可以通过torchvision.datasets模块进行导入。

MNIST构造函数中有四个参数:

  • root:该参数指定数据集所在的目录,如果在当前目录下不存在指定的数据集,则会从网站http://yann.lecun.com/exdb/mnist/上下载。
  • train:该参数指定加载训练集还是测试集(默认为True,即加载训练集)。
  • download:该参数指定是否需要自动下载数据集(默认为False)。
  • transform:用于指定数据集的预处理操作,这里使用ToTensor()将数据集转换成PyTorch中的张量格式。

在上述代码中,MNIST('./data', train=True, download=True, transform=ToTensor())的作用是:

  • 设置数据集的根目录为"./data",如果该目录下不存在相应的数据集,则自动下载。
  • 加载MNIST训练集。
  • 自动下载MNIST数据集到"./data"目录中,并将其转换为张量格式。

这里使用ToTensor()来转换为张量格式,它将原始数据集中的numpy数组格式转换成PyTorch中的张量格式。注意,数据集的每个样本都包含一张图片和一个对应的标签,被加载后的数据集可以直接作为DataLoader对象的参数。

2.2 训练样本和测试样本

训练数据集和测试数据集在机器学习中是用于训练和评估模型性能的不同数据集。

训练数据集是训练模型时使用的数据集,用于确定模型中的权重和参数。它应该涵盖尽可能多的数据,以便模型可以从中学习并适应新的、未见过的数据。通常情况下,训练数据集占总数据集的比例要更高。

测试数据集是用于评估模型性能和泛化能力的数据集。在模型训练结束后,使用测试数据集来检验模型在真实场景中的预测准确率,也就是用训练好的模型来预测测试集上的标签。与训练数据集不同的是,测试数据集通常是独立的、未被模型用于训练的样本。

区分训练数据集和测试数据集可以有效避免过度拟合模型的问题,即模型过度学习并记忆了训练数据集,但在新数据上的表现却很差。但这并不意味着我们应该完全将数据集划分为两部分:一个更好的做法是采用交叉验证,将训练数据集划分为多个折(fold)并进行交叉验证,以更好地评估模型的性能,并探索不同算法的性能差异。

在PyTorch中,MNIST是一个手写数字识别数据集。对于MNIST这个数据集,train参数用于指示是否加载训练数据集或测试数据集。

当train=True时,表示加载训练数据集,即用于训练模型的数据集,包括60000张手写数字的图像和对应的标签,可以被用来训练模型。

当train=False时,表示加载测试数据集,即用于评估训练模型性能的数据集,包括10000张手写数字的图像和对应的标签,可以被用来测试已训练好的模型在未见过的数据上的泛化能力。

训练集和测试集对应不同的数据文件。

在代码MNIST(‘./data’, train=True, download=True, transform=ToTensor())中,参数train=True表示加载训练数据集,并将数据集转换为PyTorch的tensor类型,并将其归一化(Normalize)到[0,1]的范围内。如果train=False,则会加载测试数据集并按照相同的方式进行处理。

因此,train参数的不同取值会影响数据集的加载方式,但不会影响数据预处理的方式。

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。