dataloader.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from inversefed import consts
  2. import torch
  3. from torchvision import datasets, transforms
  4. class DataLoader:
  5. def __init__(self, data, device):
  6. self.data = data
  7. self.device = device
  8. def get_mean_std(self):
  9. if self.data == 'cifar10':
  10. mean, std = consts.cifar10_mean, consts.cifar10_std
  11. elif self.data == 'cifar100':
  12. mean, std = consts.cifar100_mean, consts.cifar100_std
  13. elif self.data == 'mnist':
  14. mean, std = consts.mnist_mean, consts.mnist_std
  15. elif self.data == 'imagenet':
  16. mean, std = consts.imagenet_mean, consts.imagenet_std
  17. else:
  18. raise Exception("dataset not found")
  19. return mean, std
  20. def get_data_info(self):
  21. mean, std = self.get_mean_std()
  22. transform = transforms.Compose([transforms.ToTensor(),
  23. transforms.Normalize(mean, std)])
  24. dm = torch.as_tensor(mean)[:, None, None].to(self.device)
  25. ds = torch.as_tensor(std)[:, None, None].to(self.device)
  26. data_root = '/Users/shellmiao/Documents/adversarial-attack-from-leakage/data/cifar_data'
  27. # data_root = '~/.torch'
  28. if self.data == 'cifar10':
  29. dataset = datasets.CIFAR10(root=data_root, download=False, train=False, transform=transform)
  30. elif self.data == 'cifar100':
  31. dataset = datasets.CIFAR100(root=data_root, download=True, train=False, transform=transform)
  32. elif self.data == 'mnist':
  33. dataset = datasets.MNIST(root=data_root, download=True, train=False, transform=transform)
  34. elif self.data == 'imagenet':
  35. dataset = datasets.ImageNet(root=data_root, download=True, train=False, transform=transform)
  36. else:
  37. raise Exception("dataset not found, load your own datasets")
  38. data_shape = dataset[0][0].shape
  39. classes = dataset.classes
  40. return dataset, data_shape, classes, (dm, ds)