您现在的位置是:首页 >技术交流 >【youcans的深度学习 08】PyTorch 数据加载和转换网站首页技术交流

【youcans的深度学习 08】PyTorch 数据加载和转换

youcans_ 2023-05-23 04:00:02
简介【youcans的深度学习 08】PyTorch 数据加载和转换

欢迎关注『youcans的深度学习』系列



本节中我们讨论 PyTorch 中的 DataLoader 模块,如何处理各种形式的数据(例如 CSV 文件、图像、文本等)。

  • 数据集的处理
  • PyTorch 中的数据加载
  • 数据预处理
  • 创建自定义数据集

1. 机器学习中的数据集

机器学习中的“数据”几乎可以是你能想象到的任何东西,例如数字表格、图像、视频文件、音频文件、蛋白质结构、文本等等。

简单地说,机器学习需要把你的数据表示为数字形式,选择或构建一个模型从数据中学习。因此,数据是机器学习中知识的来源。

机器学习项目中的一个重要的步骤,是创建一个训练集和一个测试集,有时还要创建一个验证集。例如我们有了一个数据集,不论是流行的公开数据集还是自己建立的个人数据集,在建模之前都需要将其拆分。

  • 训练集(training set):用于模型训练中的学习,根据训练集数据计算梯度来优化模型参数,通常占数据集的 60-80%。

  • 验证集(Validation set):用于模型训练中的模型验证,常用于判断何时结束训练过程,但并不直接用于调整模型参数,通常占数据集的10-20%。

  • 测试集(testing set):用于对训练好的模型进行性能评估,常占数据集的10-20%。

需要特别说明的是:

(1)验证集和测试集数据的使用步骤是相同的:将验证集或测试集的样本数据作为模型的输入,使用模型进行正向计算(也即模型推理),得到模型的输出结果。

(2)验证集和测试集数据的目的是不同的:验证集通常用于模型训练过程,通过验证集的结果来决定是否进一步训练模型,以减少模型的过拟合。测试集的数据用于对已经训练好的模型(预训练模型)进行测试,用于评价模型的性能,不能根据测试结果再对模型进行调整。

(3)训练集和验证集数据都用于模型训练过程,但使用方法和目的都不同。训练集数据通常既要正向传播计算模型输出,又要将误差反向传播计算梯度,用于调整模型参数。验证集数据只做正向传播计算模型输出,不计算梯度也不做反向传播,计算结果并不用于调整模型参数。

(4)验证集并不是必须的。虽然通常认为在模型训练中设置验证集,有利于判断何时结束训练过程,以减少模型的过拟合,但这并不是必须的,也可以没有验证集。


2. PyTorch 图像分类数据集和自然语言数据集

使用通用的数据集(如 MNIST 或 CIFAR)训练神经网络,不仅可以显著地提高工作效率,而且通常可以获得更好的模型性能。这是由于通用数据集的样本结构均衡、信息高效,而且组织规范、易于处理。在处理自己建立的数据集时,就要困难的多。

  • torchvision 模块实现图像处理的核心类和方法,torchvision.datasets 包含了常用的一些图像数据集、图像模型架构和图像转换方法。
  • torchtext 模块实现自然语言处理所需的核心类和方法,torchtext.datasets 包含一些常用的文本分类数据集、语言模型数据集、机器翻译数据集、序列标注数据集、问答数据集和无监督学习数据集。
  • torchaudio 模块实现音频处理所需的核心类和方法,但该模块不太成熟。

2.1 Torchvision 中的图像分类数据集

PyTorch 提供了一些常用的图像数据集,预加载在 torchvision.datasets 类中。
torchvision 模块实现神经网络所需的核心类和方法, torchvision.datasets 包含流行的数据集、模型架构和常用的图像转换方法。

MNIST
MNIST: 经过标准化和中心裁剪的手写图像数据集,是机器学习中最常用的数据集之一。MNIST 具有 60,000 张训练图像和 10,000 张测试图像。
加载和使用MNIST数据集的方法为:

torchvision.datasets.MNIST()

Fashion MNIST
Fashion MNIST:该数据集与 MNIST 类似,但图像内容不是手写数字,而是 T 恤、裤子、包等服装类别。Fashion MNIST 具有 60,000 张训练图像和 10,000 张测试图像。
加载和使用 Fashion MNIST 数据集的方法为:

torchvision.datasets.FashionMNIST()

CIFAR
CIFAR:CIFAR 数据集包括卡车、青蛙、船、汽车、鹿等常见图像。CIFAR 数据集有两个版本:CIFAR10 和 CIFAR100。CIFAR10 有 10 个不同类别的图像,CIFAR100 有 100 个不同的类别。
加载和使用 CIFAR 数据集的方法为:

torchvision.datasets.CIFAR10()
torchvision.datasets.CIFAR100()

COCO
COCO:COCO 数据集有超过 100,000 个日常物品,如人、瓶子、文具、书籍等。COCO 图像数据集广泛用于对象检测和图像描述。
加载和使用 COCO 数据集的方法为:

torchvision.datasets.CocoCaptions()

EMNIST
EMNIST:是 MNIST 数据集的升级版,包含数字和字母的图像,非常适合文本识别。
加载和使用 EMNIST 数据集的方法为:

torchvision.datasets.EMNIST()

IMAGE-NET
IMAGE-NET: 该数据集具有超过 120 万张图像,包含 10,000 个类别。由于该数据集太大,单独的 CPU 难以处理,通常需要加载在高端硬件系统上。
加载和使用 IMAGE-NET 数据集的方法为:

torchvision.datasets.ImageNet()

以上是 torchvision 中最常用的数据集,还有数据集如:KMNIST、QMNIST、LSUN、STL10、SVHN、PhotoTour、SBU、Cityscapes、SBD、USPS、Kinetics-400。详见 PyTorch 官方文档。


2.2 Torchtext 中的文本分类数据集

PyTorch 提供了一些常用的文本数据集,预加载在 torchtext.datasets 类中。
torchtext 模块实现自然语言处理所需的核心类和方法, torchtext.datasets 包含一些常用的文本分类数据集、语言模型数据集、机器翻译数据集、序列标注数据集、问答数据集和无监督学习数据集。

IMDB
IMDB:二分类的文本分类数据集,内容是高度极端的电影评论,用于情绪分类。数据集包含 25,000 个训练样本和 25,000 条测试样本。
加载和使用 IMDB 数据集的方法为:

torchtext.datasets.IMDB()

WikiText2
WikiText2:语言模型数据集,摘自维基百科并保留了标点符号和实际的字母大小写,广泛应用于长期依赖研究。训练集36718,验证集3760,测试集4358。
加载和使用 WikiText2 数据集的方法为:

torchtext.datasets.WikiText2(root=‘.data’, split=(‘train’, ‘valid’, ‘test’))

WikiText103
WikiText103:语言模型数据集,是 WikiText2 的升级版,从维基百科的 Good 与 Featured 文章中提炼。超过 1 亿个语句的数据合集。
加载和使用 WikiText103 数据集的方法为:

torchtext.datasets.WikiText103(root=‘.data’, split=(‘train’, ‘valid’, ‘test’))

此外, torchtext.datasets 类中的数据集还有:Multi30k、UDPOS、CoNLL2000Chunking、SQuAD、ENWik9 等。


3. PyTorch 数据集的加载与转换

3.1 DataLoader 加载 MNIST 数据集

在 torch.utils.data 模块中提供了DataLoader类,用来加载和迭代数据集,可以节省内存。

from torch.utils.data import DataLoader
DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False)

参数说明:

  • dataset:数据集,加载数据的路径参数。
  • batch_size:在一次迭代中训练样本的数量。
  • shuffle:是否重排数据,布尔值,True表示所有打乱样本顺序加载,False表示按顺序加载。
  • num_worker:同时运行的进程数,默认为 0 表示只使用主进程,允许使用多进程。
  • collate_fn:合并数据集,可选参数,用于合并样本列表。
  • pin_memory:在 CUDA 张量上加载数据,True表示在返回之前将张量复制到 CUDA 内存中。

下载 MNIST 数据集
从 torchvision 下载 MNIST 数据集。
下载训练数据集时,设置参数 train=True;下载测试数据集时,设置参数 train=False。

# (2) 下载 MNIST 数据集 -- 网络加载
from torchvision.datasets import MNIST
# 将数据集加载到 DataLoader
data_train = MNIST('~/mnist_data', train=True, download=True)

# 查看和显示数据集中的样本图像
random_image = data_train[0][0]  # 样本[0] 的图像
random_image_label = data_train[0][1]  # # 样本[0] 的标签
print("The shape of the image is:", random_image.shape)
print("The label of the image is:", random_image_label)
plt.imshow(random_image)
plt.show()

运行后显示如下:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to C:UsersDavid/mnist_dataMNIST aw rain-images-idx3-ubyte.gz

加载 MNIST 数据集

使用 DataLoader 类加载 MNIST 数据集。

Dataloader 是一个迭代器,基本功能是传入一个 Dataset 对象,根据参数 batch_size 生成一个 batch 的数据。

# (3) 使用 DataLoader 类加载 MNIST 数据集 -- 网络加载
import torch
from torchvision import transforms

data_train = torch.utils.data.DataLoader(
    MNIST('~/mnist_data', 
          train=True, download=True, 
          transform = transforms.Compose([transforms.ToTensor()])),
    batch_size=64,
    shuffle=True)

for batch_idx, samples in enumerate(data_train):
      print(batch_idx, samples)

CUDA 加载 MNIST 数据集

GPU 进行模型训练的速度更快,使用 CUDA 加载数据的配置如下。

import torch
from torchvision import transforms

# test cuda
device = "cuda" if torch.cuda.is_available() else "cpu"
kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('/files/', train=True, download=True),
    batch_size=batch_size_train, **kwargs)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('files/', train=False, download=True),
    batch_size=batch_size, **kwargs)

这些 Datasets 是 torch.utils.data.Dataset 的子类,可以通过 torch.utils.data.DataLoader 使用多线程,例如:

torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)


3.2 用 transforms 进行图像格式转换

数据集中的图片格式是 PIL格式,需要转换为 tensor 张量格式。
在模型训练时,所有训练样本的大小要相同。但是,图像数据集中的图片可能具有不同的大小和分辨率。使用 transforms 可以将数据集中所有图像都转换为规定的大小和分辨率。

常用的操作是:

  • transforms.Resize():图像缩放,调整大小
  • transforms.CenterCrop():使用 CenterCrop 变换从中心裁剪图像
  • transforms.RandomResizedCrop():随机调整数据集中所有图像的大小

torchvision.datasets 加载 CIFAR10 并进行转换的过程如下。

  • Resize,将所有图像的大小调整为 32×32
  • CenterCrop,对图像应用中心裁剪变换
  • ToTensor,将裁剪图像转换为张量数据类型
  • Normalize,对张量规范化,使其介于 0.5 到 1 之间
# (4) 使用 transforms 对数据进行预处理 -- 网络加载
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 定义一个名为 transforms 的变量
transform = transforms.Compose([    
    transforms.Resize(32),  # resize 32×32    
    transforms.CenterCrop(32),  # center-crop 裁剪变换    
    transforms.ToTensor(),  # to-tensor   
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])  # 规范化

# 将数据集加载到 DataLoader
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
           download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, 
              batch_size=4, shuffle=False)

3.3 用 tensorboard 进行图片显示

在加载数据集、进行格式转换之后,可以使用 add_image() 方法显示图片,这需要 tensorboard 包的支持。

在这里我们使用for循环进行10张图片的显示

# (5) 用  tensorboard 进行图片显示 -- 网络加载
import torchvision

# 定义transforms
trans_totensor_tool = torchvision.transforms.ToTensor()

# 创建训练数据集
train_set = 
    torchvision.datasets.CIFAR10(root="./dataset3",
    train=True, transform=trans_totensor_tool, download=True)
# 创建测试数据集
test_set = 
    torchvision.datasets.CIFAR10(root="./dataset3",
    train=False, transform=trans_totensor_tool, download=True)
 
# 显示图片
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
for i in range(10):  # 循环显示 10 张图片
    img,label = test_set[i]
    writer.add_image("dataset",img,i) 
writer.close()

4. 本地数据集的加载与转换

4.1 下载 PyTorch 数据集到本地

由于某些原因,在境内直接加载 PyTorch 数据集的速度很慢,甚至可能无法下载。对此可以从网络链接直接下载数据集保存到本地,然后从本地文件加载数据集。

在 PyCharm 中下载速度很慢,可以使用迅雷等下载工具快速下载:

  • 按住Ctrl 键,鼠标左键点击数据集的单词,就可以进入数据集的帮助文档,获得下载文件名和下载链接;
  • 使用迅雷或者浏览器下载压缩文件,按照 root 中定义的路径创建文件夹,把文件放入文件夹中。

在这里插入图片描述

注意:创建的文件夹必须与 root 中定义的文件夹名称相同,否则扫描不到该数据集。

无论是否需要在线下载数据集,都推荐设置参数 download=True,可以自动完成下载解压工作。如果已经下载了数据集,这样也可以提供解压功能。


4.2 下载到本地的 PyTorch 数据集的加载和显示

使用 torchvision.datasets 不能加载本地数据集,我们可以重写一个自定义的 MyDataset 类。其作用类似于 torchvision.datasets.MNIST,可以加载预下载的 PyTorch 数据集。
注意,该方法适用于 “下载到本地的 PyTorch 数据集”,而不是自己创建的用户数据集。

class MyDataset(Dataset):
    def __init__(self, folder, data_name, label_name, transform=None):
        (train_set, train_labels) = load_data(folder, data_name, label_name)  # 可以使用 torch.load() 读取为torch.Tensor
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):
        img, target = self.train_set[index], int(self.train_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.train_set)

def load_data(data_folder, data_name, label_name):
    """
        data_folder: 文件目录
        data_name: 数据文件名
        label_name:标签数据文件名
    """
    with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:  # rb:读取二进制数据
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    return (x_train, y_train)

我们如 3.1将 MNIST 数据集下载并保存到本地路径 “…dataset”,就可以使用 MyDataset 类加载。

# (6) 使用预先下载的 MNIST 数据集
from torchvision.transforms import ToTensor
train_dataset = MyDataset("..dataset", "train-images-idx3-ubyte.gz",
                           "train-labels-idx1-ubyte.gz", transform=ToTensor())
test_dataset = MyDataset("..dataset", "t10k-images-idx3-ubyte.gz",
                           "t10k-labels-idx1-ubyte.gz", transform=ToTensor())
# 显示图片
plt.figure(figsize=(8, 5))
for i in range(10):  # 循环显示 10 张图片
    img, label = train_dataset[i]  # img 类型为 torch
    print(label, img.shape, img.numpy().shape)
    plt.subplot(2,5,i+1), plt.imshow(img[0].numpy())
plt.show()

加载后的使用与操作,与网络加载 PyTorch 数据集是相同的,可以使用 torch.utils.data.DataLoader 将数据集加载到 DataLoader 迭代生成器,使用 next迭代取出样本数据。

# (7) 使用预先下载的 MNIST 数据集
from torchvision.transforms import ToTensor
train_dataset = MyDataset("..dataset", "train-images-idx3-ubyte.gz",
                           "train-labels-idx1-ubyte.gz", transform=ToTensor())
test_dataset = MyDataset("..dataset", "t10k-images-idx3-ubyte.gz",
                           "t10k-labels-idx1-ubyte.gz", transform=ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=64)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, batch_size=64)

imgs, labels = next(iter(train_loader))  # 创建生成器,用 next 返回一个批次的数据
print(imgs.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])

plt.figure(figsize=(8, 5))
for i, img in enumerate(imgs[:10]):
    imgNP = img.numpy()  # 将 tensor 张量转为 numpy 数组
    imgNP2 = np.squeeze(imgNP)  # (1,28,28) -> (28,28)
    plt.subplot(2,5,i+1)
    plt.imshow(imgNP2)  # 绘制第 i 张图片
plt.show()

在这里插入图片描述


5. 自定义数据集的创建与加载

注意本节 “自定义数据集”不是指上节中的“下载到本地的 PyTorch 数据集”,而是指自己创建的用户数据集。这种方法调用Pytorch 内部的 API,数据集需要符合 API 规定的存放格式。

5.1 用 ImageFolder 类创建自定义数据集

创建自定义数据集,需要将数据集按照一定的结构和格式存放,并调用 ImageFolder 类。

ImageFoldertorchvision 中的通用数据加载器类,可以加载自己的图像数据集。ImageFolder 要求数据集按照如下方式组织。根目录 root 下存储的是类别文件夹(如class1, class2, …),每个类别文件夹下存储相应类别的图像(如xxx.png)

A generic data loader where the images are arranged in this way:

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

函数原型:

dataset=torchvision.datasets.ImageFolder(
root, transform=None,
target_transform=None,
loader=,
is_valid_file=None)

参数说明:

  1. root 是根目录,在 root 目录下设有不同类别的子文件夹 :
data(root)
├── train
    ├── cat
    │   ├── cat_image1.png
    │   └── cat_image1.png
    ├── dog
    │   └── dog_image1.png
    │   └── dog_image2.png
    │   └── dog_image3.png
├── valid
    ├── cat
    │   ├── cat_image_valid1.png
    │   └── cat_image_valid2.png
    ├── dog
    │   └── dog_image_valid1.png
    ...

  1. transform:对图片进行预处理,返回 transform 变换后的图片。
  2. target_transform:对图片类别进行预处理,输入为 target,输出转换后的类别索引,默认返回顺序索引 0,1, 2…。
  3. loader:数据集加载方式,通常使用默认加载方式 。
  4. is_valid_file:检查损坏文件,获取图像文件路径并检查是否有效。

建立上述结构的数据文件目录之后,可以使用 ImageFolder 加载该数据集的所有图像。

torchvision.datasets.ImageFolder(root, transform)

ImageFolder 类有三种属性:

  1. self.classes:用一个 list 保存类别名称;
  2. self.class_to_idx:类别对应的索引,与不做转换返回顺序索引的 target 对应;
  3. self.imgs:保存元组 (img_path, class) 的 list。

5.2 用 ImageFolder 加载自定义数据集的完整例程

# DataLoad02_v1.py
# Load custom Datasets with ImageFolder
# 加载自定义数据集
# Copyright: youcans@qq.com
# Crated: Huang Shan, 2023/03/06

# (1) 导入需要的库
import os
import torch
from torchvision import transforms, datasets
# from torch.utils.data import DataLoader, ImageFolder


# (2) 定义图片转换方式
train_transforms = transforms.Compose([
	transforms.RandomResizedCrop(400),
	transforms.ToTensor()])

# (3) 定义数据库的路径
path = os.path.join(os.getcwd(), "userdataset", "train")

# (4) 从指定路径导入自定义的数据库
train_dataset = torch.utils.data.ImageFolder(root=path, transform=train_transforms)

# (5) 查看自定义的数据库
print(train_dataset.classes)   # 根据文件夹的名称所确定的类别
print(train_dataset.class_to_idx)  # 顺序索引 0,1,2...
print(train_dataset.imgs)  # 元组,图片的路径和类别

# (6) 用 DataLoader 加载数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
               batch_size=64, shuffle=True)
imgs, labels = next(iter(train_loader))  # 创建生成器,用 next 返回一个批次的数据
print(imgs.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])

# # (7) 传入网络模型进行训练
# for epoch in range(epochs):
#     train_bar = tqdm(train_loader, file=sys.stdout)
#     for step, data in enumerate(train_bar):
#     ...



5.3 自定义数据集的封装

torch.utils.data.Dataset是代表数据集的抽象类。我们创建自定义数据集,要继承 Dataset类并封装 其中的__getitem__()__len__()方法。

  • __getitem__() 方法支持索引,可以使用 dataset[i] 返回数据集中的样本 i。
  • __len__() 方法,可以使用 len(dataset) 返回数据集的大小。

定义 Dataset 类的基本结构如下:

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

下面以封装 FruitImagesDataset 数据集,示例创建自定义数据集的模板。

import os
import numpy as np
import cv2
import torch
import matplotlib.patches as patches
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from xml.etree import ElementTree as et
from torchvision import transforms as torchtrans

class FruitImagesDataset(torch.utils.data.Dataset):
    def __init__(self, files_dir, width, height, transforms=None):
        self.transforms = transforms
        self.files_dir = files_dir
        self.height = height
        self.width = width


        self.imgs = [image for image in sorted(os.listdir(files_dir))
                     if image[-4:] == '.jpg']

        self.classes = ['_','apple', 'banana', 'orange']

    def __getitem__(self, idx):

        img_name = self.imgs[idx]
        image_path = os.path.join(self.files_dir, img_name)

        # reading the images and converting them to correct size and color
        img = cv2.imread(image_path)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA)
        # diving by 255
        img_res /= 255.0

        # annotation file
        annot_filename = img_name[:-4] + '.xml'
        annot_file_path = os.path.join(self.files_dir, annot_filename)

        boxes = []
        labels = []
        tree = et.parse(annot_file_path)
        root = tree.getroot()

        # cv2 image gives size as height x width
        wt = img.shape[1]
        ht = img.shape[0]

        # box coordinates for xml files are extracted and corrected for image size given
        for member in root.findall('object'):
            labels.append(self.classes.index(member.find('name').text))

            # bounding box
            xmin = int(member.find('bndbox').find('xmin').text)
            xmax = int(member.find('bndbox').find('xmax').text)

            ymin = int(member.find('bndbox').find('ymin').text)
            ymax = int(member.find('bndbox').find('ymax').text)

            xmin_corr = (xmin / wt) * self.width
            xmax_corr = (xmax / wt) * self.width
            ymin_corr = (ymin / ht) * self.height
            ymax_corr = (ymax / ht) * self.height

            boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])

        # convert boxes into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)

        # getting the areas of the boxes
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        # suppose all instances are not crowd
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)

        labels = torch.as_tensor(labels, dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["area"] = area
        target["iscrowd"] = iscrowd
        # image_id
        image_id = torch.tensor([idx])
        target["image_id"] = image_id

        if self.transforms:
            sample = self.transforms(image=img_res,
                                     bboxes=target['boxes'],
                                     labels=labels)

            img_res = sample['image']
            target['boxes'] = torch.Tensor(sample['bboxes'])
        return img_res, target
    def __len__(self):
        return len(self.imgs)

def get_transform(train):
    if train:
        return A.Compose([
            A.HorizontalFlip(0.5),
            ToTensorV2(p=1.0)
        ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
    else:
        return A.Compose([
            ToTensorV2(p=1.0)
        ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})

files_dir = '../input/fruit-images-for-object-detection/train_zip/train'
test_dir = '../input/fruit-images-for-object-detection/test_zip/test'

dataset = FruitImagesDataset(train_dir, 480, 480)


【本节完】


版权声明:
欢迎关注『youcans的深度学习』系列,转发请注明原文链接:
【youcans的深度学习 08】PyTorch 数据加载和转换(https://youcans.blog.csdn.net/article/details/130217675)
Copyright 2023 youcans, XUPT
Crated:2023-04-18


风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。