tensor_data_loader
- class snippets.tensor_data_loader.TensorDataLoader(dataset: TensorDataset, batch_size: int = 1, shuffle: bool = False)
Fast data loader for torch tensor datasets, fusing torch.utils.data.TensorDataset and torch.utils.data.DataLoader.
torch.utils.data.DataLoader is slow for torch.utils.data.TensorDatasets because it iterates over the elements of the dataset one at a time. The tensor-specific slicing implemented by TensorDataLoader is typically orders of magnitude faster for data that fit in memory.
- Parameters:
dataset – Dataset to load from.
batch_size – Number of samples per batch.
shuffle – Shuffle dataset before batching.
Example
>>> from snippets.tensor_data_loader import TensorDataLoader
>>> import torch
>>> from torch.utils.data import TensorDataset
>>> tensors = torch.randn(13, 5), torch.randn(13)
>>> dataset = TensorDataset(*tensors)
>>> loader = TensorDataLoader(dataset, batch_size=7)
>>> for X, y in loader:
...     X.shape, y.shape
(torch.Size([7, 5]), torch.Size([7]))
(torch.Size([6, 5]), torch.Size([6]))
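To illustrate the slicing idea described above, here is a minimal sketch of how a loader could batch a TensorDataset by slicing its tensors directly instead of fetching one element at a time. The class name SlicingLoaderSketch is hypothetical and this is not the actual implementation of TensorDataLoader; it assumes only that the dataset exposes its tensors through the `tensors` attribute of torch.utils.data.TensorDataset and that shuffling, when requested, is done by permuting the rows once per pass.

```python
import torch
from torch.utils.data import TensorDataset


class SlicingLoaderSketch:
    """Hypothetical stand-in showing slice-based batching (not the snippets code)."""

    def __init__(self, dataset: TensorDataset, batch_size: int = 1,
                 shuffle: bool = False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self) -> int:
        # Number of batches, counting a final partial batch.
        n = len(self.dataset)
        return (n + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        tensors = self.dataset.tensors
        n = len(self.dataset)
        if self.shuffle:
            # Assumption: shuffle by applying one random permutation to every tensor.
            perm = torch.randperm(n)
            tensors = tuple(t[perm] for t in tensors)
        for start in range(0, n, self.batch_size):
            # One slice per tensor per batch, rather than one lookup per sample.
            yield tuple(t[start:start + self.batch_size] for t in tensors)


dataset = TensorDataset(torch.randn(13, 5), torch.randn(13))
shapes = [(X.shape, y.shape) for X, y in SlicingLoaderSketch(dataset, batch_size=7)]
```

With 13 samples and batch_size=7, the sketch yields one full batch of 7 followed by a partial batch of 6, matching the shapes in the example above. Each batch costs one slice per tensor, whereas a per-element loader fetches and collates every sample separately.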