123456789101112131415161718192021222324252627282930313233343536373839 |
- import torch
- from torchvision import datasets
- from dataset import get_semi_supervised_dataset
- from easyfl.datasets.data import CIFAR100
- from transform import SimCLRTransform
- def get_data_loaders(dataset, image_size=32, batch_size=512, num_workers=8):
- transformation = SimCLRTransform(size=image_size, gaussian=False).test_transform
- if dataset == CIFAR100:
- data_path = "./data/cifar100"
- train_dataset = datasets.CIFAR100(data_path, download=True, transform=transformation)
- test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
- else:
- data_path = "./data/cifar10"
- train_dataset = datasets.CIFAR10(data_path, download=True, transform=transformation)
- test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
- return train_loader, test_loader
- def get_semi_supervised_data_loaders(dataset, data_distribution, class_per_client, label_ratio, batch_size=512, num_workers=8, image_size=32):
- transformation = SimCLRTransform(size=image_size, gaussian=False).test_transform
- if dataset == CIFAR100:
- data_path = "./data/cifar100"
- test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
- else:
- data_path = "./data/cifar10"
- test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
- _, _, labeled_data = get_semi_supervised_dataset(dataset, 5, data_distribution, class_per_client, label_ratio)
- return labeled_data.loader(batch_size), test_loader
|