The full guide to creating custom datasets and dataloaders for different models in PyTorch
Before you can build a machine learning model, you need to load your data into a dataset. Luckily, PyTorch has many utilities to help with this entire process (if you are not familiar with PyTorch, I recommend refreshing on the basics here).
PyTorch has good documentation to help with this process, but I have not found any comprehensive documentation or tutorials on custom datasets. I'm going to start with loading basic premade datasets and then work my way up to creating datasets from scratch for different models!
What is a Dataset and Dataloader?
Before we dive into code for different use cases, let's understand the difference between the two terms. Generally, you first create your dataset and then create a dataloader. A dataset contains the features and labels for each data point that will be fed into the model. A dataloader is a PyTorch iterable that wraps a dataset and handles batching, shuffling, and (optionally) parallel loading of its samples.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
The most common arguments in the dataloader are batch_size, shuffle (usually enabled only for the training data), num_workers (the number of subprocesses used to load data in parallel; 0 loads it in the main process), and pin_memory (copies fetched tensors into page-locked memory, enabling faster data transfer to CUDA-enabled GPUs).
When loading with multiple workers, avoid returning CUDA tensors from the dataset because of multiprocessing complications with CUDA. Instead, keep the data on the CPU in the dataset and set pin_memory=True so the transfer to the GPU is fast.
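As a minimal sketch of those arguments in use, the block below wraps random toy tensors (an assumption, standing in for real data) in a TensorDataset and iterates over the resulting DataLoader. num_workers is left at 0 so it runs anywhere, and pin_memory is only enabled when a GPU is present.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data standing in for a real dataset: 100 samples, 3x32x32 "images", 10 classes
features = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(features, labels)

train_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,                           # reshuffle every epoch (training data)
    num_workers=0,                          # >0 loads batches in worker subprocesses
    pin_memory=torch.cuda.is_available(),   # page-locked memory speeds host-to-GPU copies
    drop_last=False,                        # keep the final, smaller batch
)

batch_shapes = []
for images, targets in train_loader:
    batch_shapes.append((tuple(images.shape), tuple(targets.shape)))

# 100 samples / batch_size 16 -> 6 full batches and one batch of 4
print(len(train_loader))   # 7
print(batch_shapes[-1])    # ((4, 3, 32, 32), (4,))
```

Setting drop_last=True would instead discard that final batch of 4, which some models with batch normalization prefer.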
Loading a Premade Dataset
If your dataset is a premade one, whether downloaded from online or stored locally, creating the dataset is extremely simple. PyTorch has good documentation on this, so I will be brief.
If you know the dataset comes from PyTorch or is PyTorch-compatible, simply import the necessary modules and instantiate the dataset of choice:
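from torchvision import datasets
from torchvision.transforms import ToTensor
# 'path' is the root folder where the data is (or will be) stored;
# download=True fetches CIFAR-10 there if it is not already present
data = datasets.CIFAR10('path', train=True, download=True, transform=ToTensor())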
Each dataset has its own arguments to pass into it (found here). In general, they include the root path where the dataset is stored, a boolean indicating whether it needs to be downloaded (conveniently called download), whether to load the training or test split, and any transforms to apply.
Transforms
At the end of the last section I mentioned that transforms can be applied to a dataset, but what exactly is a transform?
A transform is a function that manipulates data, typically to preprocess or augment images. There are many different facets to transforms. The most common transform, ToTensor(), converts a PIL image or NumPy array into a tensor (needed to input into any model), scaling 8-bit pixel values to the range [0, 1]. Other transforms built into PyTorch (torchvision.transforms) include flipping, rotating, cropping, normalizing, and shifting images. The random augmentations among these are typically used so the model generalizes better and doesn't overfit to the training data. Data augmentation can also be used to artificially increase the effective size of the dataset if needed.
Beware: most torchvision transforms accept only PIL images or tensors, not NumPy arrays (ToTensor() is an exception). To convert from NumPy, either create a torch tensor (for example with torch.from_numpy) or convert the array to a PIL image:
from PIL import Image
# assume arr is a numpy array
# you may need to normalize and cast arr to np.uint8 depending on format
img = Image.fromarray(arr)
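Both conversion routes can be sketched side by side. The array here is a hypothetical H x W x C float image in [0, 1]; the tensor route also reorders it to C x H x W, the layout torchvision's tensor transforms expect, while the PIL route rescales to uint8 first.

```python
import numpy as np
import torch
from PIL import Image

# Hypothetical H x W x C image stored as float32 values in [0, 1]
arr = np.random.rand(64, 48, 3).astype(np.float32)

# Option 1: wrap the array as a tensor (shares memory with arr)
# and reorder HWC -> CHW for torchvision tensor transforms
tensor_img = torch.from_numpy(arr).permute(2, 0, 1)

# Option 2: build a PIL image; Pillow expects uint8 values in [0, 255] for RGB
img = Image.fromarray((arr * 255).astype(np.uint8))

print(tensor_img.shape)    # torch.Size([3, 64, 48])
print(img.size, img.mode)  # (48, 64) RGB
```

Note that PIL reports size as (width, height), the reverse of the array's row-major (height, width) order.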
Transforms can be chained and applied in sequence using torchvision.transforms.Compose. You can combine as many transforms as needed for the dataset. An example is shown below:
from torchvision import transforms
dataset_transform = transforms.Compose([
    transforms.RandomResizedCrop(256),    # random crop, resized to 256 x 256
    transforms.RandomHorizontalFlip(),    # flip left-right with probability 0.5
    transforms.ToTensor(),                # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet channel mean/std
])
Be sure to pass the saved transform into the dataset as its transform argument; the dataset then applies it to each sample as the dataloader fetches it.
Creating a Custom Dataset
In most cases of developing your own model, you will need a custom dataset. A common use case is transfer learning: fine-tuning a pretrained model on your own dataset.
There are 3 required parts to a PyTorch dataset class: initialization, length, and retrieving an element.
__init__: To initialize the dataset, pass in the raw data and the labels (or paths to them). The best practice is to pass in the raw image data and the label data separately.
__len__: Return the number of samples in the dataset. Before creating the dataset, check that the raw data and labels have the same length.
__getitem__: This is where all the data handling occurs to return the raw data and label at a given index (idx). If the initialization stored a path to the data rather than the data itself, the file must be opened and the data loaded and preprocessed here before it can be returned. If any transforms are specified, apply them here as well, converting the data to a tensor first if the transforms require it.
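The path-based case from the list above can be sketched as follows. The class name LazyNpyDataset and the .npy file format are hypothetical choices for illustration: __init__ keeps only file paths and labels, and __getitem__ reads each file on demand, so the full dataset never has to fit in memory.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class LazyNpyDataset(Dataset):
    """Hypothetical dataset that stores file paths and loads each sample on demand."""

    def __init__(self, image_paths, labels, transform=None):
        # Check up front that every image has exactly one label
        if len(image_paths) != len(labels):
            raise ValueError("image_paths and labels must have the same length")
        self.image_paths = list(image_paths)
        self.labels = list(labels)
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Only the path is held in memory; the file is read here, per sample
        image = torch.from_numpy(np.load(self.image_paths[idx]))
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]


# Usage: write a few .npy files to a temporary directory, then index the dataset
tmp_dir = tempfile.mkdtemp()
paths = []
for i in range(3):
    path = os.path.join(tmp_dir, f"img_{i}.npy")
    np.save(path, np.full((1, 8, 8), i, dtype=np.float32))
    paths.append(path)

dataset = LazyNpyDataset(paths, labels=[0, 1, 2])
image, label = dataset[2]
print(len(dataset), image.shape, label)  # 3 torch.Size([1, 8, 8]) 2
```

The same pattern works for image files by swapping np.load for PIL's Image.open.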
Example dataset for a semantic segmentation model:
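import torch
from torch.utils.data import Dataset

class ExampleDataset(Dataset):
    """Example dataset pairing raw images with segmentation masks"""

    def __init__(self, raw_img, data_mask, transform=None):
        # Images and masks must line up one-to-one
        assert len(raw_img) == len(data_mask), "images and masks differ in length"
        self.raw_img = raw_img
        self.data_mask = data_mask
        self.transform = transform

    def __len__(self):
        return len(self.raw_img)

    def __getitem__(self, idx):
        # Some samplers pass tensor indices; convert them to plain Python values
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.raw_img[idx]
        mask = self.data_mask[idx]
        sample = {'image': image, 'mask': mask}
        # The transform receives the whole dict, so it must be a callable that
        # handles both keys (standard torchvision transforms expect one image)
        if self.transform:
            sample = self.transform(sample)
        return sample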
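Since ExampleDataset passes the whole dict to its transform, a segmentation transform has to modify the image and mask together, or the mask stops lining up with the pixels. A minimal sketch of such a transform, with the hypothetical name PairedRandomHorizontalFlip, could look like this; it would be passed as ExampleDataset(raw_img, data_mask, transform=PairedRandomHorizontalFlip()).

```python
import random

import torch
from torch.utils.data import DataLoader


class PairedRandomHorizontalFlip:
    """Hypothetical dict-aware transform: flips image and mask together."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']
        # One coin flip for both, so the mask stays aligned with the image
        if random.random() < self.p:
            image = torch.flip(image, dims=[-1])  # flip along the width axis
            mask = torch.flip(mask, dims=[-1])
        return {'image': image, 'mask': mask}


flip = PairedRandomHorizontalFlip(p=1.0)  # p=1.0 forces a flip for the demo
sample = {'image': torch.arange(6.).reshape(1, 2, 3),
          'mask': torch.tensor([[0, 0, 1], [0, 1, 1]])}
flipped = flip(sample)
print(flipped['mask'])  # tensor([[1, 0, 0], [1, 1, 0]])

# The default collate function batches dicts key by key
samples = [flip(sample) for _ in range(4)]
batch = next(iter(DataLoader(samples, batch_size=4)))
print(batch['image'].shape, batch['mask'].shape)
# torch.Size([4, 1, 2, 3]) torch.Size([4, 2, 3])
```

The last lines also show that a dataloader over dict samples returns a dict of batched tensors, so the training loop can read batch['image'] and batch['mask'] directly.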
It is important to check the input of the first layer of the model (especially for a pretrained model) to make sure the shape of the data matches the expected input shape. If not, you may need to adjust the dimensions. This is common when the input image is a greyscale n x n array but the model requires a channel dimension (e.g. 1 x 256 x 256).
After the data passes through the dataset and dataloader, each batch should be in NCHW format (batch size, channels, height, width). The dataloader adds the batch dimension N itself, so __getitem__ should return each sample as CHW; any reformatting can be done there before the data reaches the model.
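For the greyscale case, a minimal sketch (GrayscaleDataset is a hypothetical name, and the random arrays stand in for real images) adds the channel axis inside __getitem__ and lets the dataloader stack the CHW samples into NCHW batches.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class GrayscaleDataset(Dataset):
    """Hypothetical dataset of n x n greyscale arrays with integer labels."""

    def __init__(self, images, labels):
        self.images = images  # shape (num_samples, n, n), no channel axis
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.from_numpy(self.images[idx]).float()
        # Add the channel axis: (H, W) -> (1, H, W), i.e. CHW per sample
        image = image.unsqueeze(0)
        return image, int(self.labels[idx])


images = np.random.rand(10, 256, 256).astype(np.float32)
labels = np.random.randint(0, 2, size=10)
loader = DataLoader(GrayscaleDataset(images, labels), batch_size=4)

batch, targets = next(iter(loader))
# The DataLoader stacks CHW samples into NCHW batches
print(batch.shape)  # torch.Size([4, 1, 256, 256])
```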
Splitting the Dataset
While creating the dataset, you may want to split it into training, validation, and test datasets. This can be done with PyTorch's built-in random_split function by specifying the size of each split. Make sure the split sizes add up to the total length of the dataset.
from torch.utils.data import random_split
train, val, test = random_split(dataset, [train_size, val_size, test_size])
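One way to guarantee the sizes add up is to derive the last split from the remainder, as in the sketch below (the 80/10/10 ratio and the toy TensorDataset are assumptions). Recent PyTorch releases also let random_split take fractions that sum to 1, but computing integer sizes explicitly works on older versions too. A seeded generator makes the split reproducible, and only the training loader shuffles.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Toy dataset of 103 samples standing in for a real one
dataset = TensorDataset(torch.randn(103, 4), torch.randint(0, 2, (103,)))

# Derive integer sizes so they always sum to len(dataset)
n = len(dataset)
train_size = int(0.8 * n)               # 82
val_size = int(0.1 * n)                 # 10
test_size = n - train_size - val_size   # 11 (absorbs the rounding remainder)

# A seeded generator makes the split reproducible across runs
generator = torch.Generator().manual_seed(42)
train, val, test = random_split(dataset, [train_size, val_size, test_size],
                                generator=generator)

# Shuffle only the training data
train_loader = DataLoader(train, batch_size=16, shuffle=True)
val_loader = DataLoader(val, batch_size=16)
test_loader = DataLoader(test, batch_size=16)

print(len(train), len(val), len(test))  # 82 10 11
```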
Data Labels
The form of the data labels depends on the model: classification, object detection, or segmentation. A classification label is a class index if the task is multiclass, or 0/1 if it is binary. An object detection label contains the bounding-box coordinates of each object (usually together with each object's class). A semantic segmentation label is a mask with the same height and width as the raw image that assigns a class to every pixel (a binary mask when there is a single foreground class). An instance segmentation label contains a separate mask for each object instance in the image.
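To make these label types concrete, here is a sketch of what each might look like as tensors for a tiny hypothetical 4 x 6 image. The detection and instance-segmentation targets follow the dict convention used by torchvision's detection models (boxes as [x1, y1, x2, y2], an int64 labels tensor, and one uint8 mask per object for Mask R-CNN); other libraries may expect different formats.

```python
import torch

H, W = 4, 6  # hypothetical image height and width

# Classification: one class index per image (int64), or 0/1 for binary tasks
class_label = torch.tensor(3)

# Semantic segmentation: one class index per pixel, same H x W as the image
semantic_mask = torch.zeros(H, W, dtype=torch.int64)
semantic_mask[1:3, 2:5] = 1  # pixels belonging to class 1

# Object detection (torchvision convention): [x1, y1, x2, y2] per object
# plus one class label per box
detection_target = {
    'boxes': torch.tensor([[2.0, 1.0, 5.0, 3.0],
                           [0.0, 0.0, 2.0, 2.0]]),
    'labels': torch.tensor([1, 2], dtype=torch.int64),
}

# Instance segmentation (e.g. Mask R-CNN) adds one binary H x W mask per object
instance_masks = torch.zeros(2, H, W, dtype=torch.uint8)
instance_masks[0, 1:3, 2:5] = 1
instance_masks[1, 0:2, 0:2] = 1
instance_target = {**detection_target, 'masks': instance_masks}

print(semantic_mask.shape, instance_target['masks'].shape)
# torch.Size([4, 6]) torch.Size([2, 4, 6])
```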
Creating a dataset is a foundational aspect of model development. A faulty dataset will cause many errors downstream when training or evaluating the model. The most common errors to watch out for are shape and type mismatches. By following this guide and referring to the PyTorch docs, you should have a working dataset!
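A cheap guard against those shape and type mismatches is to pull a single batch and check it before training starts. The helper check_batch below is hypothetical and assumes the loader yields (images, targets) pairs; the expected CHW shape and dtype would come from your model's first layer.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def check_batch(loader, expected_chw, image_dtype=torch.float32):
    """Hypothetical helper: fetch one batch and verify shape and dtype."""
    images, targets = next(iter(loader))
    if images.dim() != 4 or tuple(images.shape[1:]) != tuple(expected_chw):
        raise ValueError(f"expected NCHW with CHW={expected_chw}, got {tuple(images.shape)}")
    if images.dtype != image_dtype:
        raise TypeError(f"expected images of dtype {image_dtype}, got {images.dtype}")
    if images.shape[0] != targets.shape[0]:
        raise ValueError("batch sizes of images and targets differ")
    return tuple(images.shape), targets.dtype


dataset = TensorDataset(torch.randn(8, 3, 224, 224), torch.randint(0, 10, (8,)))
loader = DataLoader(dataset, batch_size=4)
print(check_batch(loader, expected_chw=(3, 224, 224)))
# ((4, 3, 224, 224), torch.int64)
```

A uint8 image batch, for example from forgetting ToTensor(), would raise the TypeError here instead of failing deep inside the model.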
References
- Datasets & DataLoaders – PyTorch Tutorials 2.3.0+cu121 documentation
- Writing Custom Datasets, DataLoaders and Transforms – PyTorch Tutorials 2.3.0+cu121 documentation
- Transforming and augmenting images – Torchvision 0.18 documentation
- Compose – Torchvision main documentation
Comprehensive Guide to Datasets and Dataloaders in PyTorch was originally published in Towards Data Science on Medium.