import nibabel as nib
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    """Map-style dataset that loads one ``(image, label)`` sample per file path.

    The file behind a sample is only opened (via ``nib.load``) when that
    sample is requested through ``__getitem__``; constructing the dataset
    does not touch the disk.
    """

    def __init__(self, filepaths):
        """Store the sample file paths.

        Args:
            filepaths: Sequence of paths, one per sample. It must support
                ``len()`` and integer indexing.
        """
        self.filepaths = filepaths

    def __len__(self):
        """Return the number of samples, i.e. the number of file paths."""
        return len(self.filepaths)

    def __getitem__(self, idx):
        """Load the file for sample ``idx`` and return ``(image, label)``.

        Args:
            idx: Index into ``filepaths``.

        Returns:
            A 2-tuple ``(image, label)`` read from the loaded object.
        """
        # Load image from file path
        current_training_filepath = self.filepaths[idx]
        data = nib.load(current_training_filepath)

        # NOTE(review): nibabel image objects are normally not indexable by
        # string key (voxel data is usually read with ``get_fdata()``), so
        # these lookups presumably assume a dict-like container holding
        # "image" and "label" entries -- confirm against the on-disk format.
        image = data["image"]
        label = data["label"]

        # TODO: apply any on-the-fly transformations here.

        # TODO: convert image and label to torch tensors and bring them
        # into the shape the model expects before returning.

        return image, label
# One DataLoader per data split. Only the training loader shuffles its
# samples; validation and test keep the order of their file-path lists.
training_dataloader = DataLoader(
    dataset=CustomDataset(filepaths_train),
    batch_size=32,
    shuffle=True,
)
validation_dataloader = DataLoader(
    dataset=CustomDataset(filepaths_validation),
    batch_size=8,
    shuffle=False,
)
test_dataloader = DataLoader(
    dataset=CustomDataset(filepaths_test),
    batch_size=8,
    shuffle=False,
)