123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- import logging
- import os
- from abc import abstractmethod
- from easyfl.datasets.utils.remove_users import remove
- from easyfl.datasets.utils.sample import sample, extreme
- from easyfl.datasets.utils.split_data import split_train_test
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Dataset-name identifiers; compared against the ``dataset`` argument when
# building a setting-folder name.
CIFAR10, CIFAR100 = "cifar10", "cifar100"
class BaseDataset(object):
    """The internal base dataset implementation.

    Args:
        root (str): The root directory where datasets stored.
        dataset_name (str): The name of the dataset.
        fraction (float): The fraction of the data chosen from the raw data to use.
        split_type (str): The type of statistical simulation, options: iid, niid, dir, and class.
            `iid` means independent and identically distributed data.
            `niid` means non-independent and identically distributed data for Femnist and Shakespeare.
            `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
            `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
        user (bool): A flag to indicate whether partition users of the dataset into train-test groups.
            Only applicable to LEAF datasets.
            True means partitioning users of the dataset into train-test groups.
            False means partitioning each users' samples into train-test groups.
            Must be False when split_type is 'iid'.
        iid_user_fraction (float): The fraction of the number of clients used when the split_type is 'iid'.
        train_test_split (float): The fraction of data for training; the rest are for testing.
            e.g., 0.9 means 90% of data are used for training and 10% are used for testing.
        minsample (int): The minimal number of samples in each client.
            It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
        num_class: The number of classes in this dataset.
        num_of_client (int): The targeted number of clients to construct.
        class_per_client (int): The number of classes in each client. Only applicable when the split_type is 'class'.
        setting_folder (str): Name of the sub-folder under ``root`` that holds the data of this setting;
            ``data_folder`` is ``os.path.join(root, setting_folder)``.
        seed: Random seed. Defaults to -1.
        **kwargs: Accepted for compatibility and ignored by this base class.
    """

    def __init__(self,
                 root,
                 dataset_name,
                 fraction,
                 split_type,
                 user,
                 iid_user_fraction,
                 train_test_split,
                 minsample,
                 num_class,
                 num_of_client,
                 class_per_client,
                 setting_folder,
                 seed=-1,
                 **kwargs):
        self.base_folder = root
        self.dataset_name = dataset_name
        self.fraction = fraction
        self.split_type = split_type
        self.user = user
        self.iid_user_fraction = iid_user_fraction
        self.train_test_split = train_test_split
        self.minsample = minsample
        self.num_class = num_class
        self.num_of_client = num_of_client
        self.class_per_client = class_per_client
        self.seed = seed

        if split_type == "iid":
            # Kept as an assertion (same comparison, same exception type) so
            # existing callers observe identical behavior.
            assert self.user == False, "user must be False when split_type is 'iid'"  # noqa: E712

        # Always define ``iid``: previously it was only set for "iid"/"niid",
        # leaving the attribute missing for the "dir" and "class" split types.
        self.iid = split_type == "iid"

        self.setting_folder = setting_folder
        self.data_folder = os.path.join(self.base_folder, self.setting_folder)

    @staticmethod
    def _ensure_folder(path):
        """Create ``path`` (and missing parents) if needed and return it."""
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)
        return path

    @abstractmethod
    def download_packaged_dataset_and_extract(self, filename):
        """Hook for subclasses; the base implementation always raises NotImplementedError."""
        raise NotImplementedError("download_packaged_dataset_and_extract not implemented")

    @abstractmethod
    def download_raw_file_and_extract(self):
        """Hook for subclasses; the base implementation always raises NotImplementedError."""
        raise NotImplementedError("download_raw_file_and_extract not implemented")

    @abstractmethod
    def preprocess(self):
        """Hook for subclasses; the base implementation always raises NotImplementedError."""
        raise NotImplementedError("preprocess not implemented")

    @abstractmethod
    def convert_data_to_json(self):
        """Hook for subclasses; the base implementation always raises NotImplementedError."""
        raise NotImplementedError("convert_data_to_json not implemented")

    @staticmethod
    def get_setting_folder(dataset, split_type, num_of_client, min_size, class_per_client,
                           fraction, iid_fraction, user_str, train_test_split, alpha=None, weights=None):
        """Build the underscore-joined folder name that identifies a data setting.

        For CIFAR-10/CIFAR-100 the name is made of dataset, split_type,
        num_of_client, min_size, class_per_client, alpha and a 1/0 flag telling
        whether ``weights`` is truthy. For every other dataset it is made of
        dataset, split_type, num_of_client, min_size, class_per_client,
        fraction, iid_fraction, user_str and train_test_split.

        Returns:
            str: The setting folder name.
        """
        if dataset == CIFAR10 or dataset == CIFAR100:
            return "{}_{}_{}_{}_{}_{}_{}".format(
                dataset, split_type, num_of_client, min_size, class_per_client, alpha,
                1 if weights else 0)
        return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
            dataset, split_type, num_of_client, min_size, class_per_client,
            fraction, iid_fraction, user_str, train_test_split)

    def setup(self):
        """Run the download, preprocess and JSON-conversion hooks, in that order."""
        self.download_raw_file_and_extract()
        self.preprocess()
        self.convert_data_to_json()

    def sample_customized(self):
        """Run ``sample`` unless ``<data_folder>/sampled_data`` already has content.

        Ensures that ``<base_folder>/meta`` and the ``sampled_data`` folder exist.
        """
        meta_folder = self._ensure_folder(os.path.join(self.base_folder, "meta"))
        sample_folder = self._ensure_folder(os.path.join(self.data_folder, "sampled_data"))
        # A non-empty folder is treated as "already sampled".
        if not os.listdir(sample_folder):
            sample(self.base_folder, self.data_folder, meta_folder, self.fraction,
                   self.iid, self.iid_user_fraction, self.seed)

    def sample_extreme(self):
        """Run ``extreme`` unless ``<data_folder>/sampled_data`` already has content.

        Ensures that ``<base_folder>/meta`` and the ``sampled_data`` folder exist.
        """
        meta_folder = self._ensure_folder(os.path.join(self.base_folder, "meta"))
        sample_folder = self._ensure_folder(os.path.join(self.data_folder, "sampled_data"))
        # A non-empty folder is treated as "already sampled".
        if not os.listdir(sample_folder):
            extreme(self.base_folder, self.data_folder, meta_folder, self.fraction,
                    self.num_class, self.num_of_client, self.class_per_client, self.seed)

    def remove_unqualified_user(self):
        """Run ``remove`` with ``minsample`` unless ``<data_folder>/rem_user_data`` already has content."""
        rm_folder = self._ensure_folder(os.path.join(self.data_folder, "rem_user_data"))
        if not os.listdir(rm_folder):
            remove(self.data_folder, self.dataset_name, self.minsample)

    def split_train_test_set(self):
        """Run ``split_train_test`` when both the ``train`` and ``test`` folders are empty.

        Ensures that ``<data_folder>/train`` and ``<data_folder>/test`` exist.
        The meta folder path is passed along but not created here.
        """
        meta_folder = os.path.join(self.base_folder, "meta")
        train = self._ensure_folder(os.path.join(self.data_folder, "train"))
        test = self._ensure_folder(os.path.join(self.data_folder, "test"))
        # Split only when neither output folder holds anything yet.
        if not os.listdir(train) and not os.listdir(test):
            split_train_test(self.data_folder, meta_folder, self.dataset_name,
                             self.user, self.train_test_split, self.seed)

    def sampling(self):
        """Sample according to ``split_type``, then filter users and split train/test.

        "iid" and "niid" use ``sample_customized``; "class" uses
        ``sample_extreme``; any other split type performs no sampling step.
        """
        if self.split_type in ("iid", "niid"):
            self.sample_customized()
        elif self.split_type == "class":
            self.sample_extreme()
        self.remove_unqualified_user()
        self.split_train_test_set()
|