您现在的位置是:首页 >技术杂谈 >深入浅出PyTorch数据读取机制网站首页技术杂谈

深入浅出PyTorch数据读取机制

穿着帆布鞋也能走猫步 2024-06-17 10:43:10
简介深入浅出PyTorch数据读取机制

熟悉深度学习的小伙伴一定都知道:深度学习模型训练主要由数据、模型、损失函数、优化器以及迭代训练五个模块组成。如下图所示,Pytorch数据读取机制则是数据模块中的主要分支。

在这里插入图片描述

Pytorch数据读取是通过​​Dataset​​​+​​Dataloader​​的方式完成。其中,

  • DataSet:定义数据集。将原始数据样本及对应标签映射到Dataset,便于后续通过index读取数据。同时,还可以在Dataset中进行数据格式变换、数据增强等预处理操作。
  • DataLoader:迭代读取数据集。将数据样本进行分批次Batch、打乱顺序Shuffle等处理,便于训练时迭代读取数据。

Dataset

Dataset用于解决数据从哪里读取以及如何读取的问题。 Pytorch给定的Dataset是一个抽象类,所有自定义的数据集都要继承Dataset,并重写**init()、getitem()和__len__()**类方法,以供DataLoader类直接调用。

  • init:数据集初始化。
  • getitem:定义指定索引如何获取样本数据,最终返回index对应的样本对{样本数据x:标签y}。
  • len():数据集的样本数。

下面是笔者以cifar10数据集为例实现Dataset自定义数据集的代码样例。

from torch.utils.data import Dataset
from PIL import Image
import os

class Mydata(Dataset):
    """
    步骤一:继承 torch.utils.data.Dataset 类
    """
    def __init__(self,data_dir,label_dir):
        """
        步骤二:实现 __init__ 函数,初始化数据集,将样本和标签映射到列表中
        """
        self.data_dir = data_dir
        self.label_dir = label_dir
        # 用join把路径拼接一起可以避免一些因“/”引发的错误
        self.path = os.path.join(self.data_dir,self.label_dir)
        # 将该路径下的所有文件变成一个列表
        self.img_path = os.listdir(self.path)

    def __getitem__(self,idx)
        """
        步骤三:实现 __getitem__ 函数,定义指定 index 时如何获取数据,并返回单条数据(样本数据、对应的标签)
        """
        # 根据index(idx),从列表中取出图片
        # img_path列表里每个元素就是对应图片文件名
        img_name = self.img_path[idx]
        # 获得对应图片路径
        img_item_path = os.path.join(self.data_dir,self.label_dir,img_name)
        # 使用PIL库下Image工具,打开对应路径图片
        img = Image.open(img_item_path)
        label = self.label_dir
        # 返回图片和对应标签
        return img,label

    def __len__(self):
        """
        步骤四:实现 __len__ 函数,返回数据集的样本总数
        """
        return len(self.img_path)

# data_dir,label_dir可自定义数据集目录
train_custom_dataset = MyData(data_dir,label_dir)
test_custom_dataset = MyData(data_dir,label_dir)

DataLoader

在实际项目中,当数据量很大,考虑到内存有限、I/O速度等问题,训练中不可能一次性将所有数据加载到内存或者只用一个进行加载数据,此时就需要的是多进程、迭代加载,Dataloader便应运而生。

DataLoader是一个可迭代的数据装载器,组合了数据集和采样器,并在给定数据集上提供可迭代对象。可以完成对数据集中多个对象的集成。

Pytorch的数据读取机制中DataLoader模块包括Sampler和Dataset两个子模块,其中Sampler模块生成索引index;Dataset模块是根据索引读取数据。DataLoader读取数据流程如下图所示。

在这里插入图片描述

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UZqgYimv-1684309723395)(imgs/230424183501.png)]

  • DataLoader:进入DataLoader模块。
  • DataloaderIter:进入__iter__函数判断是否采用多进程,并进入相应的读取机制。
  • Sampler:通过采样,挑选每个Batchsize该读取的数据,并返回这些数据的index。
  • index:一个batchsize数据的索引。
  • DatasetFetcher:获取index对应的数据。
  • Dataset:调用dataset[idx]获取相应数据,并拼接成list。
  • getitem:Dataset的核心,用索引获取数据。
  • img,label:读取到的数据。
  • collate_fn:将读取的数据从list转为batch形式。
  • Batch Data:batch形式数据,第一个元素是图像,第二个元素是标签。

Pytorch中DataLoader类定义如下:

class torch.utils.data.DataLoader(
     """
     构建可迭代的数据装载器,训练时,每一个for循环,每一次迭代,
     从DataLoader中获取一个batch_size大小的数据
     """
     dataset,
     batch_size=1,
     shuffle=False,
     sampler=None,
     batch_sampler=None,
     num_workers=0,
     collate_fn=None,
     pin_memory=False,
     drop_last=False,
)
  • dataset:需要加载的数据集,Dataset对象。
  • batch_size:每批次读取样本数。例如batch_size=16表示每批次读取16个样本。
  • shuffle:每个epoch是否乱序。shuffle=True表示在取数据时打乱样本顺序,以减少过拟合发生的可能。
  • sampler:索引index。
  • batch_sampler:将返回一个索引的sampler进行包装,按照设定的batch_size返回一组索引。
  • num_workers:同步/异步读取数据。num_workers=0表示数据加载是同步的,在主进程中完成。num_workers的值设为大于0时,即开启多进程方式异步加载数据,可提升数据读取速度。
  • pin_memory:是否将数据拷贝到拷贝到临时缓冲区。
  • collate_fn:将多个样本组合在一起变成一个mini-batch,不指定该函数的话会调用Pytorch内部默认的函数。
  • drop_last:丢弃不完整的批次样本,drop_last=True表示当数据集样本数不能被batch_size整除时,则丢弃最后一个不完整的batch样本。

补充说明

Epoch:所有训练样本都已输入到模型中,称为一个epoch
Iteration:一批样本(batch_size)输入到模型中,称为一个Iteration。
Batchsize:一批样本的大小,称为Batchsize。用于决定一个epoch有多少个Iteration。

代码实现示例如下。

import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

# 将数据集转换为torch可识别的类型
torch_dataset = Data.TensorDataset(x, y)

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('epoch', epoch,
              '| step:', step,
              '| batch_x', batch_x.numpy(),
              '| batch_y:', batch_y.numpy())

在这里插入图片描述

通过上述方法即可初始化一个数据读取器loader,用于加载训练数据集torch_dataset。

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。