tensor_data_loader

class snippets.tensor_data_loader.TensorDataLoader(dataset: TensorDataset, batch_size: int = 1, shuffle: bool = False)

Fast data loader for torch tensor datasets, fusing torch.utils.data.TensorDataset and torch.utils.data.DataLoader.

torch.utils.data.DataLoader is slow for a torch.utils.data.TensorDataset because it fetches samples one at a time and then collates them into a batch. TensorDataLoader instead slices each underlying tensor directly, which is typically orders of magnitude faster for data that fit in memory.
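As a simplified picture, assuming DataLoader's default map-style fetching, one batch of a TensorDataset is assembled roughly as below: one __getitem__ call per sample, then default_collate to stack the pieces. Slicing produces the same batch with a single indexing operation per tensor. This sketch only contrasts the two access patterns; it is not taken from TensorDataLoader's source.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataloader import default_collate

dataset = TensorDataset(torch.randn(13, 5), torch.randn(13))

# Per-element path (simplified): 7 Python-level __getitem__ calls,
# each returning a tuple (x_i, y_i), followed by stacking in default_collate.
samples = [dataset[i] for i in range(7)]
X_slow, y_slow = default_collate(samples)

# Slicing path: one indexing operation per underlying tensor.
X_fast, y_fast = (t[:7] for t in dataset.tensors)

# Both paths yield identical batches; only the amount of work differs.
print(torch.equal(X_slow, X_fast), torch.equal(y_slow, y_fast))
```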

Parameters:
  • dataset – Dataset to load from.

  • batch_size – Number of samples per batch.

  • shuffle – Whether to shuffle the dataset before batching.
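A minimal sketch of slicing-based batching with these two parameters follows. SlicingLoader is a hypothetical name, and drawing a fresh torch.randperm permutation on every pass is an assumption about how shuffle behaves; this is not the snippets implementation.

```python
import torch
from torch.utils.data import TensorDataset


class SlicingLoader:
    """Hypothetical slicing loader for illustration; not snippets.TensorDataLoader."""

    def __init__(self, dataset: TensorDataset, batch_size: int = 1, shuffle: bool = False):
        # All tensors in a TensorDataset share the first (sample) dimension.
        self.tensors = dataset.tensors
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        tensors = self.tensors
        n = tensors[0].shape[0]
        if self.shuffle:
            # Assumption: a new permutation per pass, applied to every tensor
            # so that rows stay aligned across tensors.
            perm = torch.randperm(n)
            tensors = tuple(t[perm] for t in tensors)
        for start in range(0, n, self.batch_size):
            # The last batch is shorter when n is not a multiple of batch_size.
            yield tuple(t[start:start + self.batch_size] for t in tensors)


dataset = TensorDataset(torch.arange(13.0).unsqueeze(1), torch.arange(13))
for X, y in SlicingLoader(dataset, batch_size=7, shuffle=True):
    print(X.shape, y.shape)
```

Indexing every tensor with the same permutation keeps each sample's features and target together, and 13 samples with batch_size=7 split into batches of 7 and 6, as in the Example section.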

Example

>>> from snippets.tensor_data_loader import TensorDataLoader
>>> import torch
>>> from torch.utils.data import TensorDataset
>>> tensors = torch.randn(13, 5), torch.randn(13)
>>> dataset = TensorDataset(*tensors)
>>> loader = TensorDataLoader(dataset, batch_size=7)
>>> for X, y in loader:
...     X.shape, y.shape
(torch.Size([7, 5]), torch.Size([7]))
(torch.Size([6, 5]), torch.Size([6]))