import torch from torch.utils.data import sampler from torchvision import datasets from torch.utils.data import DataLoader from torch.utils.data import SubsetRandomSampler from torchvision import transforms class UnNormalize(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, tensor): """ Parameters: ------------ tensor (Tensor): Tensor image of size (C, H, W) to be normalized. Returns: ------------ Tensor: Normalized image. """ for t, m, s in zip(tensor, self.mean, self.std): t.mul_(s).add_(m) return tensor def get_dataloaders_mnist(batch_size, num_workers=0, validation_fraction=None, train_transforms=None, test_transforms=None): if train_transforms is None: train_transforms = transforms.ToTensor() if test_transforms is None: test_transforms = transforms.ToTensor() train_dataset = datasets.MNIST(root='data', train=True, transform=train_transforms, download=True) valid_dataset = datasets.MNIST(root='data', train=True, transform=test_transforms) test_dataset = datasets.MNIST(root='data', train=False, transform=test_transforms) if validation_fraction is not None: num = int(validation_fraction * 60000) train_indices = torch.arange(0, 60000 - num) valid_indices = torch.arange(60000 - num, 60000) train_sampler = SubsetRandomSampler(train_indices) valid_sampler = SubsetRandomSampler(valid_indices) valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, num_workers=num_workers, sampler=valid_sampler) train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, sampler=train_sampler) else: train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=True) test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) if validation_fraction is None: return train_loader, test_loader else: return train_loader, valid_loader, test_loader def get_dataloaders_cifar10(batch_size, num_workers=0, validation_fraction=None, train_transforms=None, test_transforms=None): if train_transforms is None: train_transforms = transforms.ToTensor() if test_transforms is None: test_transforms = transforms.ToTensor() train_dataset = datasets.CIFAR10(root='data', train=True, transform=train_transforms, download=True) valid_dataset = datasets.CIFAR10(root='data', train=True, transform=test_transforms) test_dataset = datasets.CIFAR10(root='data', train=False, transform=test_transforms) if validation_fraction is not None: num = int(validation_fraction * 50000) train_indices = torch.arange(0, 50000 - num) valid_indices = torch.arange(50000 - num, 50000) train_sampler = SubsetRandomSampler(train_indices) valid_sampler = SubsetRandomSampler(valid_indices) valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, num_workers=num_workers, sampler=valid_sampler) train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, sampler=train_sampler) else: train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=True) test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) if validation_fraction is None: return train_loader, test_loader else: return train_loader, valid_loader, test_loader