eval_dataset.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import torch
  2. from torchvision import datasets
  3. from dataset import get_semi_supervised_dataset
  4. from easyfl.datasets.data import CIFAR100
  5. from transform import SimCLRTransform
  6. def get_data_loaders(dataset, image_size=32, batch_size=512, num_workers=8):
  7. transformation = SimCLRTransform(size=image_size, gaussian=False).test_transform
  8. if dataset == CIFAR100:
  9. data_path = "./data/cifar100"
  10. train_dataset = datasets.CIFAR100(data_path, download=True, transform=transformation)
  11. test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
  12. else:
  13. data_path = "./data/cifar10"
  14. train_dataset = datasets.CIFAR10(data_path, download=True, transform=transformation)
  15. test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)
  16. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
  17. test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
  18. return train_loader, test_loader
  19. def get_semi_supervised_data_loaders(dataset, data_distribution, class_per_client, label_ratio, batch_size=512, num_workers=8, image_size=32):
  20. transformation = SimCLRTransform(size=image_size, gaussian=False).test_transform
  21. if dataset == CIFAR100:
  22. data_path = "./data/cifar100"
  23. test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
  24. else:
  25. data_path = "./data/cifar10"
  26. test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)
  27. test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
  28. _, _, labeled_data = get_semi_supervised_dataset(dataset, 5, data_distribution, class_per_client, label_ratio)
  29. return labeled_data.loader(batch_size), test_loader