How to Create a Custom PyTorch DataLoader
First, create a custom dataset class by subclassing torch.utils.data.Dataset and implementing __len__ and __getitem__.
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, features, labels):
        # features and labels must have the same length
        assert len(features) == len(labels)
        self.features = features
        self.labels = labels

    def __len__(self):
        # number of samples in the dataset
        return len(self.features)

    def __getitem__(self, idx):
        # return a single (feature, label) pair
        return self.features[idx], self.labels[idx]
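As a quick sanity check, the dataset can be indexed directly (a minimal sketch with made-up toy data):

# toy data for illustration only
ds = CustomDataset([[1, 2, 3], [4, 5, 6]], [7, 8])
print(len(ds))   # 2
print(ds[0])     # ([1, 2, 3], 7)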
Next, wrap the dataset in a DataLoader, where we specify the batch size.
features, labels = load_data()
# features & labels must have equal lengths
# e.g. features = [[1,2,3],[4,5,6]]
#      labels = [7,8]
dataset = CustomDataset(features, labels)

batch_size = 2  # pick a batch size that suits your model and memory
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True)
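A quick way to inspect one batch is shown below. This is a sketch, not output from the original post: note that if features and labels are plain Python lists of lists, the default collate function will not stack them into a single tensor, so converting to tensors first is a common choice.

import torch

# hypothetical toy data, converted to tensors so the default
# collate function stacks each batch into a single tensor
features = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
labels = torch.tensor([7, 8], dtype=torch.float32)

dataset = CustomDataset(features, labels)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

xb, yb = next(iter(dataloader))
print(xb.shape)  # torch.Size([2, 3])
print(yb.shape)  # torch.Size([2])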
Finally, iterate over the dataloader during training.
for epoch in range(num_epochs):
    for x, y in dataloader:
        ...  # do stuff: forward pass, compute loss, backprop, optimizer step
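For concreteness, here is what "do stuff" might look like with a hypothetical linear model and optimizer. The model, loss, and hyperparameters are assumptions for illustration, not part of the original post.

import torch
from torch import nn

# hypothetical model, loss, and optimizer for illustration
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

num_epochs = 5
for epoch in range(num_epochs):
    for x, y in dataloader:
        # assumes x is a (batch_size, 3) float tensor, as in the toy example above
        optimizer.zero_grad()
        pred = model(x.float()).squeeze(-1)  # shape: (batch_size,)
        loss = loss_fn(pred, y.float())
        loss.backward()
        optimizer.step()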