12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- import logging
- import os
- import torchvision
- from easyfl.datasets.simulation import data_simulation
- from easyfl.datasets.utils.base_dataset import BaseDataset, CIFAR100
- from easyfl.datasets.utils.util import save_dict
- logger = logging.getLogger(__name__)  # module-level logger, named after this module; used by Cifar100.preprocess for progress messages
class Cifar100(BaseDataset):
    """CIFAR-100 dataset wrapper built on ``BaseDataset``.

    The raw train/test splits are fetched through
    ``torchvision.datasets.CIFAR100`` and the training split is partitioned
    across simulated clients with ``data_simulation``. The partitioned
    training data and the untouched test data are then persisted with
    ``save_dict``.

    Args:
        root: Forwarded to ``BaseDataset``.
        fraction: Forwarded to ``BaseDataset``.
        split_type: Forwarded to ``BaseDataset`` and later passed to
            ``data_simulation`` to select how data is partitioned.
        user: Forwarded to ``BaseDataset``.
        iid_user_fraction: Forwarded to ``BaseDataset``.
        train_test_split: Forwarded to ``BaseDataset``.
        minsample: Forwarded to ``BaseDataset``; also stored as
            ``self.min_size`` and passed to ``data_simulation``.
        num_class: Forwarded to ``BaseDataset``.
        num_of_client: Forwarded to ``BaseDataset``; also passed to
            ``data_simulation``.
        class_per_client: Forwarded to ``BaseDataset``; also passed to
            ``data_simulation``.
        setting_folder: Forwarded to ``BaseDataset``.
        seed: Forwarded to ``BaseDataset``.
        weights: Passed to ``data_simulation``. When it is not ``None`` the
            simulation is always re-run, even if saved training data exists.
        alpha: Passed to ``data_simulation``.
    """

    # NOTE(review): CIFAR-100 has 100 classes, yet ``num_class`` defaults to
    # 80. The default is kept unchanged for backward compatibility -- confirm
    # whether it is intentional.
    def __init__(self,
                 root,
                 fraction,
                 split_type,
                 user,
                 iid_user_fraction=0.1,
                 train_test_split=0.9,
                 minsample=10,
                 num_class=80,
                 num_of_client=100,
                 class_per_client=2,
                 setting_folder=None,
                 seed=-1,
                 weights=None,
                 alpha=0.5):
        super(Cifar100, self).__init__(root,
                                       CIFAR100,
                                       fraction,
                                       split_type,
                                       user,
                                       iid_user_fraction,
                                       train_test_split,
                                       minsample,
                                       num_class,
                                       num_of_client,
                                       class_per_client,
                                       setting_folder,
                                       seed)
        # Filled in by download_raw_file_and_extract() as {'x': ..., 'y': ...}.
        self.train_data, self.test_data = {}, {}
        # Settings consumed by preprocess() when calling data_simulation().
        self.split_type = split_type
        self.num_of_client = num_of_client
        self.weights = weights
        self.alpha = alpha
        self.min_size = minsample
        self.class_per_client = class_per_client

    def download_packaged_dataset_and_extract(self, filename):
        """Do nothing; this dataset has no packaged-download step."""
        pass

    def download_raw_file_and_extract(self):
        """Download CIFAR-100 via torchvision and keep it in memory.

        Both splits are downloaded into ``self.base_folder`` (presumably set
        by ``BaseDataset`` -- verify) and stored as ``{'x': data,
        'y': targets}`` dictionaries in ``self.train_data`` and
        ``self.test_data``.
        """
        train_set = torchvision.datasets.CIFAR100(root=self.base_folder, train=True, download=True)
        test_set = torchvision.datasets.CIFAR100(root=self.base_folder, train=False, download=True)

        self.train_data = {
            'x': train_set.data,
            'y': train_set.targets,
        }
        self.test_data = {
            'x': test_set.data,
            'y': test_set.targets,
        }

    def preprocess(self):
        """Partition the training data across clients and save both splits.

        The work is skipped when no explicit ``weights`` were supplied and a
        ``train`` entry already exists under ``self.data_folder``. Otherwise
        ``data_simulation`` is run on the in-memory training data, and the
        simulated training data plus the raw test data are written with
        ``save_dict``.

        Expects ``download_raw_file_and_extract`` to have populated
        ``self.train_data`` and ``self.test_data`` beforehand.
        """
        train_data_path = os.path.join(self.data_folder, "train")
        test_data_path = os.path.join(self.data_folder, "test")

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(self.data_folder, exist_ok=True)

        # Reuse previously simulated data unless custom weights force a rerun.
        if self.weights is None and os.path.exists(train_data_path):
            return

        logger.info("Start CIFAR100 data simulation")
        # The first element of the returned pair is not needed here.
        _, train_data = data_simulation(self.train_data['x'],
                                        self.train_data['y'],
                                        self.num_of_client,
                                        self.split_type,
                                        self.weights,
                                        self.alpha,
                                        self.min_size,
                                        self.class_per_client)
        logger.info("Complete CIFAR100 data simulation")

        save_dict(train_data, train_data_path)
        save_dict(self.test_data, test_data_path)

    def convert_data_to_json(self):
        """Do nothing; no JSON conversion is performed for this dataset."""
        pass
|