您现在的位置是:首页 >技术交流 >【youcans的深度学习 08】PyTorch 数据加载和转换网站首页技术交流
【youcans的深度学习 08】PyTorch 数据加载和转换
欢迎关注『youcans的深度学习』系列
【youcans的深度学习 08】PyTorch 数据加载和转换
本节中我们讨论 PyTorch 中的 DataLoader 模块,如何处理各种形式的数据(例如 CSV 文件、图像、文本等)。
- 数据集的处理
- PyTorch 中的数据加载
- 数据预处理
- 创建自定义数据集
1. 机器学习中的数据集
机器学习中的“数据”几乎可以是你能想象到的任何东西,例如数字表格、图像、视频文件、音频文件、蛋白质结构、文本等等。
简单地说,机器学习需要把你的数据表示为数字形式,选择或构建一个模型从数据中学习。因此,数据是机器学习中知识的来源。
机器学习项目中的一个重要的步骤,是创建一个训练集和一个测试集,有时还要创建一个验证集。例如我们有了一个数据集,不论是流行的公开数据集还是自己建立的个人数据集,在建模之前都需要将其拆分。
-
训练集(training set):用于模型训练中的学习,根据训练集数据计算梯度来优化模型参数,通常占数据集的 60-80%。
-
验证集(Validation set):用于模型训练中的模型验证,常用于判断何时结束训练过程,但并不直接用于调整模型参数,通常占数据集的10-20%。
-
测试集(testing set):用于对训练好的模型进行性能评估,常占数据集的10-20%。
需要特别说明的是:
(1)验证集和测试集数据的使用步骤是相同的:将验证集或测试集的样本数据作为模型的输入,使用模型进行正向计算(也即模型推理),得到模型的输出结果。
(2)验证集和测试集数据的目的是不同的:验证集通常用于模型训练过程,通过验证集的结果来决定是否进一步训练模型,以减少模型的过拟合。测试集的数据用于对已经训练好的模型(预训练模型)进行测试,用于评价模型的性能,不能根据测试结果再对模型进行调整。
(3)训练集和验证集数据都用于模型训练过程,但使用方法和目的都不同。训练集数据通常既要正向传播计算模型输出,又要将误差反向传播计算梯度,用于调整模型参数。验证集数据只做正向传播计算模型输出,不计算梯度也不做反向传播,计算结果并不用于调整模型参数。
(4)验证集并不是必须的。虽然通常认为在模型训练中设置验证集,有利于判断何时结束训练过程,以减少模型的过拟合,但这并不是必须的,也可以没有验证集。
2. PyTorch 图像分类数据集和自然语言数据集
使用通用的数据集(如 MNIST 或 CIFAR)训练神经网络,不仅可以显著地提高工作效率,而且通常可以获得更好的模型性能。这是由于通用数据集的样本结构均衡、信息高效,而且组织规范、易于处理。在处理自己建立的数据集时,就要困难的多。
torchvision
模块实现图像处理的核心类和方法,torchvision.datasets
包含了常用的一些图像数据集、图像模型架构和图像转换方法。torchtext
模块实现自然语言处理所需的核心类和方法,torchtext.datasets
包含一些常用的文本分类数据集、语言模型数据集、机器翻译数据集、序列标注数据集、问答数据集和无监督学习数据集。torchaudio
模块实现音频处理所需的核心类和方法,但该模块不太成熟。
2.1 Torchvision 中的图像分类数据集
PyTorch 提供了一些常用的图像数据集,预加载在 torchvision.datasets
类中。
torchvision
模块实现神经网络所需的核心类和方法, torchvision.datasets
包含流行的数据集、模型架构和常用的图像转换方法。
MNIST
MNIST: 经过标准化和中心裁剪的手写图像数据集,是机器学习中最常用的数据集之一。MNIST 具有 60,000 张训练图像和 10,000 张测试图像。
加载和使用MNIST数据集的方法为:
torchvision.datasets.MNIST()
Fashion MNIST
Fashion MNIST:该数据集与 MNIST 类似,但图像内容不是手写数字,而是 T 恤、裤子、包等服装类别。Fashion MNIST 具有 60,000 张训练图像和 10,000 张测试图像。
加载和使用 Fashion MNIST 数据集的方法为:
torchvision.datasets.FashionMNIST()
CIFAR
CIFAR:CIFAR 数据集包括卡车、青蛙、船、汽车、鹿等常见图像。CIFAR 数据集有两个版本:CIFAR10 和 CIFAR100。CIFAR10 有 10 个不同类别的图像,CIFAR100 有 100 个不同的类别。
加载和使用 CIFAR 数据集的方法为:
torchvision.datasets.CIFAR10()
torchvision.datasets.CIFAR100()
COCO
COCO:COCO 数据集有超过 100,000 个日常物品,如人、瓶子、文具、书籍等。COCO 图像数据集广泛用于对象检测和图像描述。
加载和使用 COCO 数据集的方法为:
torchvision.datasets.CocoCaptions()
EMNIST
EMNIST:是 MNIST 数据集的升级版,包含数字和字母的图像,非常适合文本识别。
加载和使用 EMNIST 数据集的方法为:
torchvision.datasets.EMNIST()
IMAGE-NET
IMAGE-NET: 该数据集具有超过 120 万张图像,包含 10,000 个类别。由于该数据集太大,单独的 CPU 难以处理,通常需要加载在高端硬件系统上。
加载和使用 IMAGE-NET 数据集的方法为:
torchvision.datasets.ImageNet()
以上是 torchvision 中最常用的数据集,还有数据集如:KMNIST、QMNIST、LSUN、STL10、SVHN、PhotoTour、SBU、Cityscapes、SBD、USPS、Kinetics-400。详见 PyTorch 官方文档。
2.2 Torchtext 中的文本分类数据集
PyTorch 提供了一些常用的文本数据集,预加载在 torchtext.datasets
类中。
torchtext
模块实现自然语言处理所需的核心类和方法, torchtext.datasets
包含一些常用的文本分类数据集、语言模型数据集、机器翻译数据集、序列标注数据集、问答数据集和无监督学习数据集。
IMDB
IMDB:二分类的文本分类数据集,内容是高度极端的电影评论,用于情绪分类。数据集包含 25,000 个训练样本和 25,000 条测试样本。
加载和使用 IMDB 数据集的方法为:
torchtext.datasets.IMDB()
WikiText2
WikiText2:语言模型数据集,摘自维基百科并保留了标点符号和实际的字母大小写,广泛应用于长期依赖研究。训练集36718,验证集3760,测试集4358。
加载和使用 WikiText2 数据集的方法为:
torchtext.datasets.WikiText2(root=‘.data’, split=(‘train’, ‘valid’, ‘test’))
WikiText103
WikiText103:语言模型数据集,是 WikiText2 的升级版,从维基百科的 Good 与 Featured 文章中提炼。超过 1 亿个语句的数据合集。
加载和使用 WikiText103 数据集的方法为:
torchtext.datasets.WikiText103(root=‘.data’, split=(‘train’, ‘valid’, ‘test’))
此外, torchtext.datasets
类中的数据集还有:Multi30k、UDPOS、CoNLL2000Chunking、SQuAD、ENWik9 等。
3. PyTorch 数据集的加载与转换
3.1 DataLoader 加载 MNIST 数据集
在 torch.utils.data 模块中提供了DataLoader
类,用来加载和迭代数据集,可以节省内存。
from torch.utils.data import DataLoader
DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=None,
pin_memory=False)
参数说明:
- dataset:数据集,加载数据的路径参数。
- batch_size:在一次迭代中训练样本的数量。
- shuffle:是否重排数据,布尔值,True表示所有打乱样本顺序加载,False表示按顺序加载。
- num_worker:同时运行的进程数,默认为 0 表示只使用主进程,允许使用多进程。
- collate_fn:合并数据集,可选参数,用于合并样本列表。
- pin_memory:在 CUDA 张量上加载数据,True表示在返回之前将张量复制到 CUDA 内存中。
下载 MNIST 数据集
从 torchvision 下载 MNIST 数据集。
下载训练数据集时,设置参数 train=True;下载测试数据集时,设置参数 train=False。
# (2) 下载 MNIST 数据集 -- 网络加载
from torchvision.datasets import MNIST
# 将数据集加载到 DataLoader
data_train = MNIST('~/mnist_data', train=True, download=True)
# 查看和显示数据集中的样本图像
random_image = data_train[0][0] # 样本[0] 的图像
random_image_label = data_train[0][1] # # 样本[0] 的标签
print("The shape of the image is:", random_image.shape)
print("The label of the image is:", random_image_label)
plt.imshow(random_image)
plt.show()
运行后显示如下:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to C:UsersDavid/mnist_dataMNIST aw rain-images-idx3-ubyte.gz
加载 MNIST 数据集
使用 DataLoader 类加载 MNIST 数据集。
Dataloader 是一个迭代器,基本功能是传入一个 Dataset 对象,根据参数 batch_size 生成一个 batch 的数据。
# (3) 使用 DataLoader 类加载 MNIST 数据集 -- 网络加载
import torch
from torchvision import transforms
data_train = torch.utils.data.DataLoader(
MNIST('~/mnist_data',
train=True, download=True,
transform = transforms.Compose([transforms.ToTensor()])),
batch_size=64,
shuffle=True)
for batch_idx, samples in enumerate(data_train):
print(batch_idx, samples)
CUDA 加载 MNIST 数据集
GPU 进行模型训练的速度更快,使用 CUDA 加载数据的配置如下。
import torch
from torchvision import transforms
# test cuda
device = "cuda" if torch.cuda.is_available() else "cpu"
kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('/files/', train=True, download=True),
batch_size=batch_size_train, **kwargs)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('files/', train=False, download=True),
batch_size=batch_size, **kwargs)
这些 Datasets 是 torch.utils.data.Dataset 的子类,可以通过 torch.utils.data.DataLoader 使用多线程,例如:
torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
3.2 用 transforms 进行图像格式转换
数据集中的图片格式是 PIL格式,需要转换为 tensor 张量格式。
在模型训练时,所有训练样本的大小要相同。但是,图像数据集中的图片可能具有不同的大小和分辨率。使用 transforms 可以将数据集中所有图像都转换为规定的大小和分辨率。
常用的操作是:
- transforms.Resize():图像缩放,调整大小
- transforms.CenterCrop():使用 CenterCrop 变换从中心裁剪图像
- transforms.RandomResizedCrop():随机调整数据集中所有图像的大小
从 torchvision.datasets
加载 CIFAR10 并进行转换的过程如下。
- Resize,将所有图像的大小调整为 32×32
- CenterCrop,对图像应用中心裁剪变换
- ToTensor,将裁剪图像转换为张量数据类型
- Normalize,对张量规范化,使其介于 0.5 到 1 之间
# (4) 使用 transforms 对数据进行预处理 -- 网络加载
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 定义一个名为 transforms 的变量
transform = transforms.Compose([
transforms.Resize(32), # resize 32×32
transforms.CenterCrop(32), # center-crop 裁剪变换
transforms.ToTensor(), # to-tensor
transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])]) # 规范化
# 将数据集加载到 DataLoader
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,
batch_size=4, shuffle=False)
3.3 用 tensorboard 进行图片显示
在加载数据集、进行格式转换之后,可以使用 add_image()
方法显示图片,这需要 tensorboard 包的支持。
在这里我们使用for循环进行10张图片的显示
# (5) 用 tensorboard 进行图片显示 -- 网络加载
import torchvision
# 定义transforms
trans_totensor_tool = torchvision.transforms.ToTensor()
# 创建训练数据集
train_set =
torchvision.datasets.CIFAR10(root="./dataset3",
train=True, transform=trans_totensor_tool, download=True)
# 创建测试数据集
test_set =
torchvision.datasets.CIFAR10(root="./dataset3",
train=False, transform=trans_totensor_tool, download=True)
# 显示图片
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
for i in range(10): # 循环显示 10 张图片
img,label = test_set[i]
writer.add_image("dataset",img,i)
writer.close()
4. 本地数据集的加载与转换
4.1 下载 PyTorch 数据集到本地
由于某些原因,在境内直接加载 PyTorch 数据集的速度很慢,甚至可能无法下载。对此可以从网络链接直接下载数据集保存到本地,然后从本地文件加载数据集。
在 PyCharm 中下载速度很慢,可以使用迅雷等下载工具快速下载:
- 按住
Ctrl
键,鼠标左键点击数据集的单词,就可以进入数据集的帮助文档,获得下载文件名和下载链接; - 使用迅雷或者浏览器下载压缩文件,按照 root 中定义的路径创建文件夹,把文件放入文件夹中。
注意:创建的文件夹必须与 root 中定义的文件夹名称相同,否则扫描不到该数据集。
无论是否需要在线下载数据集,都推荐设置参数 download=True,可以自动完成下载解压工作。如果已经下载了数据集,这样也可以提供解压功能。
4.2 下载到本地的 PyTorch 数据集的加载和显示
使用 torchvision.datasets
不能加载本地数据集,我们可以重写一个自定义的 MyDataset 类。其作用类似于 torchvision.datasets.MNIST
,可以加载预下载的 PyTorch 数据集。
注意,该方法适用于 “下载到本地的 PyTorch 数据集”,而不是自己创建的用户数据集。
class MyDataset(Dataset):
def __init__(self, folder, data_name, label_name, transform=None):
(train_set, train_labels) = load_data(folder, data_name, label_name) # 可以使用 torch.load() 读取为torch.Tensor
self.train_set = train_set
self.train_labels = train_labels
self.transform = transform
def __getitem__(self, index):
img, target = self.train_set[index], int(self.train_labels[index])
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.train_set)
def load_data(data_folder, data_name, label_name):
"""
data_folder: 文件目录
data_name: 数据文件名
label_name:标签数据文件名
"""
with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath: # rb:读取二进制数据
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:
x_train = np.frombuffer(
imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
return (x_train, y_train)
我们如 3.1将 MNIST 数据集下载并保存到本地路径 “…dataset”,就可以使用 MyDataset 类加载。
# (6) 使用预先下载的 MNIST 数据集
from torchvision.transforms import ToTensor
train_dataset = MyDataset("..dataset", "train-images-idx3-ubyte.gz",
"train-labels-idx1-ubyte.gz", transform=ToTensor())
test_dataset = MyDataset("..dataset", "t10k-images-idx3-ubyte.gz",
"t10k-labels-idx1-ubyte.gz", transform=ToTensor())
# 显示图片
plt.figure(figsize=(8, 5))
for i in range(10): # 循环显示 10 张图片
img, label = train_dataset[i] # img 类型为 torch
print(label, img.shape, img.numpy().shape)
plt.subplot(2,5,i+1), plt.imshow(img[0].numpy())
plt.show()
加载后的使用与操作,与网络加载 PyTorch 数据集是相同的,可以使用 torch.utils.data.DataLoader
将数据集加载到 DataLoader 迭代生成器,使用 next
迭代取出样本数据。
# (7) 使用预先下载的 MNIST 数据集
from torchvision.transforms import ToTensor
train_dataset = MyDataset("..dataset", "train-images-idx3-ubyte.gz",
"train-labels-idx1-ubyte.gz", transform=ToTensor())
test_dataset = MyDataset("..dataset", "t10k-images-idx3-ubyte.gz",
"t10k-labels-idx1-ubyte.gz", transform=ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=64)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False, batch_size=64)
imgs, labels = next(iter(train_loader)) # 创建生成器,用 next 返回一个批次的数据
print(imgs.shape) # torch.Size([64, 1, 28, 28])
print(labels.shape) # torch.Size([64])
plt.figure(figsize=(8, 5))
for i, img in enumerate(imgs[:10]):
imgNP = img.numpy() # 将 tensor 张量转为 numpy 数组
imgNP2 = np.squeeze(imgNP) # (1,28,28) -> (28,28)
plt.subplot(2,5,i+1)
plt.imshow(imgNP2) # 绘制第 i 张图片
plt.show()
5. 自定义数据集的创建与加载
注意本节 “自定义数据集”不是指上节中的“下载到本地的 PyTorch 数据集”,而是指自己创建的用户数据集。这种方法调用Pytorch 内部的 API,数据集需要符合 API 规定的存放格式。
5.1 用 ImageFolder 类创建自定义数据集
创建自定义数据集,需要将数据集按照一定的结构和格式存放,并调用 ImageFolder 类。
ImageFolder
是 torchvision
中的通用数据加载器类,可以加载自己的图像数据集。ImageFolder
要求数据集按照如下方式组织。根目录 root 下存储的是类别文件夹(如class1, class2, …),每个类别文件夹下存储相应类别的图像(如xxx.png)
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
函数原型:
dataset=torchvision.datasets.ImageFolder(
root, transform=None,
target_transform=None,
loader=,
is_valid_file=None)
参数说明:
- root 是根目录,在 root 目录下设有不同类别的子文件夹 :
data(root)
├── train
├── cat
│ ├── cat_image1.png
│ └── cat_image1.png
├── dog
│ └── dog_image1.png
│ └── dog_image2.png
│ └── dog_image3.png
├── valid
├── cat
│ ├── cat_image_valid1.png
│ └── cat_image_valid2.png
├── dog
│ └── dog_image_valid1.png
...
- transform:对图片进行预处理,返回 transform 变换后的图片。
- target_transform:对图片类别进行预处理,输入为 target,输出转换后的类别索引,默认返回顺序索引 0,1, 2…。
- loader:数据集加载方式,通常使用默认加载方式 。
- is_valid_file:检查损坏文件,获取图像文件路径并检查是否有效。
建立上述结构的数据文件目录之后,可以使用 ImageFolder
加载该数据集的所有图像。
torchvision.datasets.ImageFolder(root, transform)
ImageFolder 类有三种属性:
- self.classes:用一个 list 保存类别名称;
- self.class_to_idx:类别对应的索引,与不做转换返回顺序索引的 target 对应;
- self.imgs:保存元组 (img_path, class) 的 list。
5.2 用 ImageFolder 加载自定义数据集的完整例程
# DataLoad02_v1.py
# Load custom Datasets with ImageFolder
# 加载自定义数据集
# Copyright: youcans@qq.com
# Crated: Huang Shan, 2023/03/06
# (1) 导入需要的库
import os
import torch
from torchvision import transforms, datasets
# from torch.utils.data import DataLoader, ImageFolder
# (2) 定义图片转换方式
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(400),
transforms.ToTensor()])
# (3) 定义数据库的路径
path = os.path.join(os.getcwd(), "userdataset", "train")
# (4) 从指定路径导入自定义的数据库
train_dataset = torch.utils.data.ImageFolder(root=path, transform=train_transforms)
# (5) 查看自定义的数据库
print(train_dataset.classes) # 根据文件夹的名称所确定的类别
print(train_dataset.class_to_idx) # 顺序索引 0,1,2...
print(train_dataset.imgs) # 元组,图片的路径和类别
# (6) 用 DataLoader 加载数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=64, shuffle=True)
imgs, labels = next(iter(train_loader)) # 创建生成器,用 next 返回一个批次的数据
print(imgs.shape) # torch.Size([64, 1, 28, 28])
print(labels.shape) # torch.Size([64])
# # (7) 传入网络模型进行训练
# for epoch in range(epochs):
# train_bar = tqdm(train_loader, file=sys.stdout)
# for step, data in enumerate(train_bar):
# ...
5.3 自定义数据集的封装
torch.utils.data.Dataset
是代表数据集的抽象类。我们创建自定义数据集,要继承 Dataset类并封装 其中的__getitem__()
和__len__()
方法。
__getitem__()
方法支持索引,可以使用 dataset[i] 返回数据集中的样本 i。__len__()
方法,可以使用 len(dataset) 返回数据集的大小。
定义 Dataset 类的基本结构如下:
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
下面以封装 FruitImagesDataset 数据集,示例创建自定义数据集的模板。
import os
import numpy as np
import cv2
import torch
import matplotlib.patches as patches
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from xml.etree import ElementTree as et
from torchvision import transforms as torchtrans
class FruitImagesDataset(torch.utils.data.Dataset):
def __init__(self, files_dir, width, height, transforms=None):
self.transforms = transforms
self.files_dir = files_dir
self.height = height
self.width = width
self.imgs = [image for image in sorted(os.listdir(files_dir))
if image[-4:] == '.jpg']
self.classes = ['_','apple', 'banana', 'orange']
def __getitem__(self, idx):
img_name = self.imgs[idx]
image_path = os.path.join(self.files_dir, img_name)
# reading the images and converting them to correct size and color
img = cv2.imread(image_path)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA)
# diving by 255
img_res /= 255.0
# annotation file
annot_filename = img_name[:-4] + '.xml'
annot_file_path = os.path.join(self.files_dir, annot_filename)
boxes = []
labels = []
tree = et.parse(annot_file_path)
root = tree.getroot()
# cv2 image gives size as height x width
wt = img.shape[1]
ht = img.shape[0]
# box coordinates for xml files are extracted and corrected for image size given
for member in root.findall('object'):
labels.append(self.classes.index(member.find('name').text))
# bounding box
xmin = int(member.find('bndbox').find('xmin').text)
xmax = int(member.find('bndbox').find('xmax').text)
ymin = int(member.find('bndbox').find('ymin').text)
ymax = int(member.find('bndbox').find('ymax').text)
xmin_corr = (xmin / wt) * self.width
xmax_corr = (xmax / wt) * self.width
ymin_corr = (ymin / ht) * self.height
ymax_corr = (ymax / ht) * self.height
boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr])
# convert boxes into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
# getting the areas of the boxes
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# suppose all instances are not crowd
iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
labels = torch.as_tensor(labels, dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["area"] = area
target["iscrowd"] = iscrowd
# image_id
image_id = torch.tensor([idx])
target["image_id"] = image_id
if self.transforms:
sample = self.transforms(image=img_res,
bboxes=target['boxes'],
labels=labels)
img_res = sample['image']
target['boxes'] = torch.Tensor(sample['bboxes'])
return img_res, target
def __len__(self):
return len(self.imgs)
def get_transform(train):
if train:
return A.Compose([
A.HorizontalFlip(0.5),
ToTensorV2(p=1.0)
], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
else:
return A.Compose([
ToTensorV2(p=1.0)
], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
files_dir = '../input/fruit-images-for-object-detection/train_zip/train'
test_dir = '../input/fruit-images-for-object-detection/test_zip/test'
dataset = FruitImagesDataset(train_dir, 480, 480)
【本节完】
版权声明:
欢迎关注『youcans的深度学习』系列,转发请注明原文链接:
【youcans的深度学习 08】PyTorch 数据加载和转换(https://youcans.blog.csdn.net/article/details/130217675)
Copyright 2023 youcans, XUPT
Crated:2023-04-18