您现在的位置是:首页 >技术杂谈 >深度学习 -- Dataset与DataLoader网站首页技术杂谈

深度学习 -- Dataset与DataLoader

冲鸭嘟嘟可 2023-06-30 12:00:02
简介深度学习 -- Dataset与DataLoader

前言

在模型训练的步骤中,数据的部分非常重要,它的过程主要分为数据收集、数据划分、数据读取、数据预处理。

数据收集的有原始样本标签(Img,label)

数据集的划分需要分为训练集、验证集、测试集。
训练集负责训练模型,验证集负责验证模型是否过拟合,测试集是用来测试性能的。

数据读取主要就是DataLoader的内容

  • DataLoader分为两个子模块,分别是Sampler和DataSet
    Sampler的功能是生成索引(Index)
    DataSet则是根据索引来读取数据

数据预处理需要用transforms来实现

DataLoader 和 Dataset

DataLoader 和 Dataset是pytorch数据读取的核心

  • torch.utils.data.DataLoader

功能:构建可跌倒的数据装载器

在这里插入图片描述

  • dataset:Dataset类,决定数据从哪读取及如何读取
  • batchsize:批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

Epoch:所有训练样本都一输入到模型中,称之为一个Epoch
Iteration:一批样本输入到模型中,称之为一个Iteration
Batchsize:批大小,决定一个Epoch有多少个Iteration

Dataset

torch.utils.data.Dataset
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__ getitem __()

getitem:
接收一个索引,返回一个样本

在这里插入图片描述

数据读取的三个问题

在这里插入图片描述

1、读哪些数据?

2、从哪读数据?

通过os库对硬盘上的文件读取

3、怎么读数据?

if __name__ == '__main__':
	random.seed(1)

	dataset_dir = ps.path.join('..','data')
	split_dir = ps.path.join('..','split')
	train_dir = os.path.join(split_dir,'train')
	valid_dir = os.path.join(split_dir,'valid')
	test_dir = os.path.join(split_dir,'test')
	
	train_pct = 0.8
	valid_pct = 0.1
	test_pct = 0.1

构建MyDataset实例

train_data = MyDataset(data_dir=train_dir,transform=train_transform)
valid_data = MyDataset(data_dir=train_dir,transform=valid_transform)

构建DataLoder

train_loader = DataLoader(dataset=train_data,batch_size=tensor(32,32),shuffle=True)
valid_loader = DataLoader(dataset=valid_data,batch_size=tensor(32,32))

在这里插入图片描述

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。