
PyTorch Notes: Dataset and DataLoader

_森罗万象 2024-08-27 12:01:03

Sources: a Bilibili video, the official tutorial, and the API reference.

  • A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.
import os
import pandas as pd
from torch.utils.data import Dataset  # base class for the custom dataset below
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.
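What DataLoader adds on top of a Dataset can be seen in a minimal pure-Python sketch (the function name and toy data below are illustrative, not PyTorch APIs): reshuffle the indices each epoch, then hand out samples a minibatch at a time.

```python
import random

# Toy map-style "dataset": anything indexable with a length.
data = [(f"img_{i}", i % 2) for i in range(10)]  # (feature, label) pairs

def iterate_minibatches(dataset, batch_size, shuffle=True, seed=0):
    """Yield lists of samples; one call over all indices is one 'epoch'."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # reshuffle to reduce overfitting
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]

batches = list(iterate_minibatches(data, batch_size=4, shuffle=False))
print([len(b) for b in batches])  # [4, 4, 2] — last batch is short
```

The real DataLoader additionally collates each batch and can fetch samples with worker processes.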

An introduction to DataLoader, based on its source code:

  • collate_fn operates on minibatches, whereas a Dataset's transform processes a single sample
  • Most Dataset classes in practice are map-style datasets

sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be any Iterable with `__len__`
implemented. If specified, `shuffle` must not be specified.

batch_sampler (Sampler or Iterable, optional): like `sampler`, but
returns a batch of indices at a time. Mutually exclusive with
`batch_size`, `shuffle`, `sampler`, and `drop_last`.
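Per the contract above, any iterable of indices with `__len__` qualifies as a sampler. A hypothetical reverse-order sampler (illustrative only, not part of PyTorch):

```python
class ReverseSampler:
    """Draws dataset indices from last to first; any iterable of indices
    with __len__ satisfies DataLoader's sampler contract."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

data = ["a", "b", "c", "d"]
print(list(ReverseSampler(data)))  # [3, 2, 1, 0]
```

Passing this via `sampler=` together with `shuffle=True` would violate the mutual-exclusivity rule quoted above.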

  • If the sampler argument is not set, a default sampler is chosen:
 if sampler is None:  # give default samplers
     if self._dataset_kind == _DatasetKind.Iterable:
         # See NOTE [ Custom Samplers and IterableDataset ]
         sampler = _InfiniteConstantSampler()
     else:  # map-style
         if shuffle:
             sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
         else:
             sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
RandomSampler returns a random permutation of the indices; the trailing partial chunk (when num_samples is not a multiple of n) is produced by:

 yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

SequentialSampler returns indices in their original order:

 return iter(range(len(self.data_source)))
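The behavior of the two default samplers can be emulated in plain Python (a sketch of the logic only, not PyTorch's actual classes; `random.sample` stands in for `torch.randperm`):

```python
import random

def sequential_indices(data_source):
    # Like SequentialSampler: indices 0..len-1 in original order
    return iter(range(len(data_source)))

def random_indices(data_source, seed=None):
    # Like RandomSampler (replacement=False): one random permutation
    # per epoch; random.sample stands in for torch.randperm here
    n = len(data_source)
    return iter(random.Random(seed).sample(range(n), n))

data = list("abcde")
print(list(sequential_indices(data)))        # [0, 1, 2, 3, 4]
print(sorted(random_indices(data, seed=0)))  # [0, 1, 2, 3, 4]
```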
  • If the batch_sampler argument is not set, a default BatchSampler is created; it draws indices from sampler and returns them one batch at a time:
if batch_size is not None and batch_sampler is None:
    # auto_collation without custom batch_sampler
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
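BatchSampler's grouping amounts to the following sketch (simplified from the source, not the actual implementation): collect indices from the sampler until a batch is full, and keep a short final batch only when drop_last is False.

```python
def batch_indices(sampler, batch_size, drop_last):
    """Group indices yielded by `sampler` into lists of length batch_size;
    a short final batch is kept unless drop_last is True."""
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch

print(list(batch_indices(range(10), 3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(batch_indices(range(10), 3, drop_last=True)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```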
  • If the collate_fn argument is not set, the default depends on auto-collation: when a batch_sampler is in use (auto-collation), default_collate is called with the whole batch and stacks the samples into batched tensors; otherwise default_convert is used, which does little beyond converting arrays to tensors:
@property
def _auto_collation(self):
    return self.batch_sampler is not None
    
if collate_fn is None:
    if self._auto_collation:
        collate_fn = _utils.collate.default_collate
    else:
        collate_fn = _utils.collate.default_convert
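Conceptually, default_collate transposes a batch of (feature, label) samples into (features, labels) and stacks each field (torch.stack for tensors). A pure-Python sketch of the transpose step, with lists standing in for stacked tensors (illustrative, not PyTorch's implementation):

```python
def collate_pairs(batch):
    """Turn [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...]).
    default_collate additionally stacks each list into a single tensor."""
    features, labels = zip(*batch)
    return list(features), list(labels)

batch = [([1, 2], 0), ([3, 4], 1), ([5, 6], 0)]
xs, ys = collate_pairs(batch)
print(xs)  # [[1, 2], [3, 4], [5, 6]]
print(ys)  # [0, 1, 0]
```

A custom collate_fn with this signature is the usual hook for padding variable-length samples within a batch.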
  • The video also covers _index_sampler and _get_iterator.