PyTorch Notes: Dataset and DataLoader
- A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # CSV with one row per image: file name in column 0, label in column 1
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # number of samples in the dataset
        return len(self.img_labels)

    def __getitem__(self, idx):
        # load the image and label for a single index
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
- The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.
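As a minimal sketch of the minibatching and reshuffling described above, the following wraps a small in-memory dataset in a DataLoader. `ToyDataset` and its contents are hypothetical stand-ins, not part of the original notes.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset: 10 samples with 3 features each."""

    def __init__(self):
        self.features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
        self.labels = torch.arange(10)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


# shuffle=True draws a new random order at the start of every epoch.
# num_workers=0 loads in the main process; a positive value uses worker processes.
loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, num_workers=0)

batch_shapes = []
seen_labels = []
for features, labels in loader:
    batch_shapes.append(tuple(features.shape))
    seen_labels.extend(labels.tolist())
```

With 10 samples and `batch_size=4`, the last batch holds the 2 remaining samples unless `drop_last=True` is passed.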
DataLoader overview, with excerpts from the source:
- collate_fn operates on a whole minibatch, whereas a Dataset's transform processes a single sample.
- Most Dataset types are map-style datasets.
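To make the split between per-sample and per-batch processing concrete, here is a small sketch. `VarLenDataset` and `pad_collate` are hypothetical names chosen for illustration: the transform doubles each sample, and the collate function pads a list of samples to a common length.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


class VarLenDataset(Dataset):
    """Hypothetical map-style dataset of variable-length sequences."""

    def __init__(self, lengths, transform=None):
        self.lengths = lengths
        self.transform = transform

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        seq = torch.ones(self.lengths[idx])
        if self.transform:
            seq = self.transform(seq)  # applied to one sample at a time
        return seq


def pad_collate(batch):
    # batch is a list of samples; pad to the longest one and stack
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0.0)
    return padded, lengths


dataset = VarLenDataset([2, 5, 3, 4], transform=lambda seq: seq * 2)
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)
batches = list(loader)
```

The dataset is map-style because it supports `dataset[idx]` and `len(dataset)`, which is what lets the loader fetch samples by index.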
sampler (Sampler or Iterable, optional): defines the strategy to draw
    samples from the dataset. Can be any ``Iterable`` with ``__len__``
    implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
    returns a batch of indices at a time. Mutually exclusive with
    :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
    and :attr:`drop_last`.
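The two parameters documented above might be used as follows. This is a sketch on a toy `TensorDataset`; the choice of `SubsetRandomSampler` and of even indices is only an illustration.

```python
import torch
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    SequentialSampler,
    SubsetRandomSampler,
    TensorDataset,
)

dataset = TensorDataset(torch.arange(10))

# Custom sampler: draw only the even indices, in random order.
# shuffle stays unset because the sampler already decides the order.
even_sampler = SubsetRandomSampler(list(range(0, 10, 2)))
loader = DataLoader(dataset, batch_size=2, sampler=even_sampler)
drawn = sorted(x.item() for (batch,) in loader for x in batch)

# Custom batch_sampler: yields whole lists of indices, so batch_size,
# shuffle, sampler, and drop_last are left at their defaults.
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=True)
batch_loader = DataLoader(dataset, batch_sampler=batch_sampler)
batches = [batch.tolist() for (batch,) in batch_loader]

# Passing both shuffle and sampler is rejected by the constructor.
try:
    DataLoader(dataset, shuffle=True, sampler=even_sampler)
    conflict_raised = False
except ValueError:
    conflict_raised = True
```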
- If the sampler argument is not set, a default sampler is chosen:
if sampler is None:  # give default samplers
    if self._dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
        else:
            sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
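The branch above can be observed from the outside through the loader's `sampler` attribute. A quick check on a toy map-style dataset (the dataset itself is an assumption for illustration):

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(6))

plain_loader = DataLoader(dataset)                 # no shuffle requested
shuffled_loader = DataLoader(dataset, shuffle=True)

plain_kind = type(plain_loader.sampler).__name__
shuffled_kind = type(shuffled_loader.sampler).__name__
```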
- In RandomSampler, torch.randperm returns a random permutation of the indices, and yield from yields them one at a time:
# tail of __iter__ (no replacement): leftover num_samples % n indices after the full permutations
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
SequentialSampler returns the indices in their original order:
return iter(range(len(self.data_source)))
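Both samplers can be iterated directly, which makes their behavior easy to inspect. The sketch below also includes a tiny hypothetical generator, `flatten`, to show what `yield from` does on its own.

```python
import torch
from torch.utils.data import RandomSampler, SequentialSampler

data_source = list(range(5))

# SequentialSampler walks range(len(data_source)) in order.
sequential_indices = list(SequentialSampler(data_source))

# RandomSampler yields a permutation of the same indices.
generator = torch.Generator()
generator.manual_seed(0)
random_indices = list(RandomSampler(data_source, generator=generator))


def flatten(chunks):
    # yield from delegates to each inner iterable in turn
    for chunk in chunks:
        yield from chunk


flattened = list(flatten([[0, 1], [2], [3, 4]]))
```

Samplers only produce indices; the data itself is fetched later through the dataset's `__getitem__`.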
- If the batch_sampler argument is not set, a default batch_sampler is created; it collects the indices drawn by sampler into a batch and returns that batch:
if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
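A short sketch of what the default BatchSampler produces, using a sequential sampler over 10 indices (the sizes are arbitrary):

```python
from torch.utils.data import BatchSampler, SequentialSampler

sampler = SequentialSampler(range(10))

# drop_last=False keeps the final, smaller batch
keep_last = list(BatchSampler(sampler, batch_size=3, drop_last=False))

# drop_last=True discards it
without_last = list(BatchSampler(sampler, batch_size=3, drop_last=True))
```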
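- If the collate_fn argument is not set, the default depends on whether automatic batching is active. With a batch_sampler (the usual case, since the default batch_size=1 creates one), default_collate is called with the list of samples in the batch and stacks them into batched tensors. Without a batch_sampler (batch_size=None), default_convert is used, which only converts NumPy arrays to tensors and otherwise does essentially nothing: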
# property on DataLoader
@property
def _auto_collation(self):
    return self.batch_sampler is not None

# in DataLoader.__init__
if collate_fn is None:
    if self._auto_collation:
        collate_fn = _utils.collate.default_collate
    else:
        collate_fn = _utils.collate.default_convert
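The difference between the two defaults can be seen by calling them directly. This sketch imports them from the private `_utils.collate` module referenced in the excerpt above; the sample values are made up.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data._utils.collate import default_collate, default_convert

samples = [(torch.tensor([1.0, 2.0]), 0), (torch.tensor([3.0, 4.0]), 1)]

# default_collate: takes a list of samples and stacks each field into a batch
features, labels = default_collate(samples)

# default_convert: takes one sample and leaves tensors and ints as they are
converted = default_convert(samples[0])

# batch_size=None disables automatic batching, so each item is a single sample
unbatched_loader = DataLoader(TensorDataset(torch.arange(4)), batch_size=None)
first_item = next(iter(unbatched_loader))
```

`features` has a new leading batch dimension, while `converted` and `first_item` keep the shape of a single sample.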
- The video also covers _index_sampler and _get_iterator.
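The notes do not go into these internals, so the following is only a rough sketch of how the pieces above fit together in the single-process case, not the library's actual implementation. `simple_loader` is a hypothetical helper: it draws index batches from a batch sampler, fetches each sample through `__getitem__`, and collates the result.

```python
import torch
from torch.utils.data import BatchSampler, SequentialSampler, TensorDataset
from torch.utils.data._utils.collate import default_collate


def simple_loader(dataset, batch_size, drop_last=False):
    """Hypothetical single-process stand-in for iterating a DataLoader."""
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    for indices in batch_sampler:                # e.g. [0, 1, 2]
        samples = [dataset[i] for i in indices]  # one __getitem__ call per index
        yield default_collate(samples)           # stack samples into a batch


dataset = TensorDataset(torch.arange(7), torch.arange(7) * 10)
batches = [(x.tolist(), y.tolist()) for x, y in simple_loader(dataset, batch_size=3)]
```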