base_dataset.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import logging
  2. import os
  3. from abc import abstractmethod
  4. from easyfl.datasets.utils.remove_users import remove
  5. from easyfl.datasets.utils.sample import sample, extreme
  6. from easyfl.datasets.utils.split_data import split_train_test
  7. logger = logging.getLogger(__name__)
  8. CIFAR10 = "cifar10"
  9. CIFAR100 = "cifar100"
  10. class BaseDataset(object):
  11. """The internal base dataset implementation.
  12. Args:
  13. root (str): The root directory where datasets stored.
  14. dataset_name (str): The name of the dataset.
  15. fraction (float): The fraction of the data chosen from the raw data to use.
  16. num_of_clients (int): The targeted number of clients to construct.
  17. split_type (str): The type of statistical simulation, options: iid, dir, and class.
  18. `iid` means independent and identically distributed data.
  19. `niid` means non-independent and identically distributed data for Femnist and Shakespeare.
  20. `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
  21. `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
  22. minsample (int): The minimal number of samples in each client.
  23. It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
  24. class_per_client (int): The number of classes in each client. Only applicable when the split_type is 'class'.
  25. iid_user_fraction (float): The fraction of the number of clients used when the split_type is 'iid'.
  26. user (bool): A flag to indicate whether partition users of the dataset into train-test groups.
  27. Only applicable to LEAF datasets.
  28. True means partitioning users of the dataset into train-test groups.
  29. False means partitioning each users' samples into train-test groups.
  30. train_test_split (float): The fraction of data for training; the rest are for testing.
  31. e.g., 0.9 means 90% of data are used for training and 10% are used for testing.
  32. num_class: The number of classes in this dataset.
  33. seed: Random seed.
  34. """
  35. def __init__(self,
  36. root,
  37. dataset_name,
  38. fraction,
  39. split_type,
  40. user,
  41. iid_user_fraction,
  42. train_test_split,
  43. minsample,
  44. num_class,
  45. num_of_client,
  46. class_per_client,
  47. setting_folder,
  48. seed=-1,
  49. **kwargs):
  50. # file_path = os.path.dirname(os.path.realpath(__file__))
  51. # self.base_folder = os.path.join(os.path.dirname(file_path), "data", dataset_name)
  52. self.base_folder = root
  53. self.dataset_name = dataset_name
  54. self.fraction = fraction
  55. self.split_type = split_type # iid, niid, class
  56. self.user = user
  57. self.iid_user_fraction = iid_user_fraction
  58. self.train_test_split = train_test_split
  59. self.minsample = minsample
  60. self.num_class = num_class
  61. self.num_of_client = num_of_client
  62. self.class_per_client = class_per_client
  63. self.seed = seed
  64. if split_type == "iid":
  65. assert self.user == False
  66. self.iid = True
  67. elif split_type == "niid":
  68. # if niid, user can be either True or False
  69. self.iid = False
  70. self.setting_folder = setting_folder
  71. self.data_folder = os.path.join(self.base_folder, self.setting_folder)
  72. @abstractmethod
  73. def download_packaged_dataset_and_extract(self, filename):
  74. raise NotImplementedError("download_packaged_dataset_and_extract not implemented")
  75. @abstractmethod
  76. def download_raw_file_and_extract(self):
  77. raise NotImplementedError("download_raw_file_and_extract not implemented")
  78. @abstractmethod
  79. def preprocess(self):
  80. raise NotImplementedError("preprocess not implemented")
  81. @abstractmethod
  82. def convert_data_to_json(self):
  83. raise NotImplementedError("convert_data_to_json not implemented")
  84. @staticmethod
  85. def get_setting_folder(dataset, split_type, num_of_client, min_size, class_per_client,
  86. fraction, iid_fraction, user_str, train_test_split, alpha=None, weights=None):
  87. if dataset == CIFAR10 or dataset == CIFAR100:
  88. return "{}_{}_{}_{}_{}_{}_{}".format(dataset, split_type, num_of_client, min_size, class_per_client, alpha,
  89. 1 if weights else 0)
  90. else:
  91. return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(dataset, split_type, num_of_client, min_size, class_per_client,
  92. fraction, iid_fraction, user_str, train_test_split)
  93. def setup(self):
  94. self.download_raw_file_and_extract()
  95. self.preprocess()
  96. self.convert_data_to_json()
  97. def sample_customized(self):
  98. meta_folder = os.path.join(self.base_folder, "meta")
  99. if not os.path.exists(meta_folder):
  100. os.makedirs(meta_folder)
  101. sample_folder = os.path.join(self.data_folder, "sampled_data")
  102. if not os.path.exists(sample_folder):
  103. os.makedirs(sample_folder)
  104. if not os.listdir(sample_folder):
  105. sample(self.base_folder, self.data_folder, meta_folder, self.fraction, self.iid, self.iid_user_fraction, self.seed)
  106. def sample_extreme(self):
  107. meta_folder = os.path.join(self.base_folder, "meta")
  108. if not os.path.exists(meta_folder):
  109. os.makedirs(meta_folder)
  110. sample_folder = os.path.join(self.data_folder, "sampled_data")
  111. if not os.path.exists(sample_folder):
  112. os.makedirs(sample_folder)
  113. if not os.listdir(sample_folder):
  114. extreme(self.base_folder, self.data_folder, meta_folder, self.fraction, self.num_class, self.num_of_client, self.class_per_client, self.seed)
  115. def remove_unqualified_user(self):
  116. rm_folder = os.path.join(self.data_folder, "rem_user_data")
  117. if not os.path.exists(rm_folder):
  118. os.makedirs(rm_folder)
  119. if not os.listdir(rm_folder):
  120. remove(self.data_folder, self.dataset_name, self.minsample)
  121. def split_train_test_set(self):
  122. meta_folder = os.path.join(self.base_folder, "meta")
  123. train = os.path.join(self.data_folder, "train")
  124. if not os.path.exists(train):
  125. os.makedirs(train)
  126. test = os.path.join(self.data_folder, "test")
  127. if not os.path.exists(test):
  128. os.makedirs(test)
  129. if not os.listdir(train) and not os.listdir(test):
  130. split_train_test(self.data_folder, meta_folder, self.dataset_name, self.user, self.train_test_split, self.seed)
  131. def sampling(self):
  132. if self.split_type == "iid":
  133. self.sample_customized()
  134. elif self.split_type == "niid":
  135. self.sample_customized()
  136. elif self.split_type == "class":
  137. self.sample_extreme()
  138. self.remove_unqualified_user()
  139. self.split_train_test_set()