data.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. import importlib
  2. import json
  3. import logging
  4. import os
  5. from easyfl.datasets.dataset import FederatedTensorDataset
  6. from easyfl.datasets.utils.base_dataset import BaseDataset, CIFAR10, CIFAR100
  7. from easyfl.datasets.utils.util import load_dict
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)
  def read_dir(data_dir):
      """Read every ``.json`` shard in ``data_dir`` and merge them.

      Each shard is read as a JSON object with a ``user_data`` key and, optionally,
      a ``hierarchies`` key. Files are visited in sorted name order so that the
      returned ``groups`` list is deterministic (``os.listdir`` gives no ordering
      guarantee, which would make comparing train/test groups order-dependent).

      Args:
          data_dir (str): Directory containing the ``.json`` shards. Files with any
              other extension are ignored.

      Returns:
          list[str]: Sorted client ids, i.e. the keys of the merged ``user_data``.
          list: Concatenated ``hierarchies`` entries from all shards, in sorted file
              order (empty if no shard has that key).
          dict: Merged ``user_data`` mapping client id to that client's data. If the
              same id appears in several shards, the later file (by name) wins.
      """
      groups = []
      data = {}
      json_files = sorted(name for name in os.listdir(data_dir) if name.endswith('.json'))
      for file_name in json_files:
          file_path = os.path.join(data_dir, file_name)
          # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
          with open(file_path, 'r', encoding='utf-8') as inf:
              cdata = json.load(inf)
          if 'hierarchies' in cdata:
              groups.extend(cdata['hierarchies'])
          data.update(cdata['user_data'])
      # Client ids are derived from the merged ``user_data`` keys; the shards'
      # ``users`` lists are not consulted.
      clients = sorted(data)
      return clients, groups, data
  def read_data(dataset_name, train_data_dir, test_data_dir):
      """Load datasets from data directories.

      Args:
          dataset_name (str): The name of the dataset.
          train_data_dir (str): The directory of training data.
          test_data_dir (str): The directory of testing data.

      Returns:
          list[str]: A list of client ids (empty for CIFAR datasets).
          list[str]: A list of group ids for dataset with hierarchies (empty for CIFAR datasets).
          dict: A dictionary of training data, e.g., {"id1": {"x": data, "y": label}, "id2": {"x": data, "y": label}}.
          dict: A dictionary of testing data. The format is same as training data for FEMNIST and Shakespeare datasets.
              For CIFAR datasets, the format is {"x": data, "y": label}, for centralized testing in the server.

      Raises:
          ValueError: If, for a non-CIFAR dataset, the training and testing directories
              disagree on client ids or on group ids.
      """
      if dataset_name in (CIFAR10, CIFAR100):
          # CIFAR splits are loaded whole via ``load_dict``; there are no client or
          # group listings to return.
          train_data = load_dict(train_data_dir)
          test_data = load_dict(test_data_dir)
          return [], [], train_data, test_data

      # Data in the directories are `json` files with keys `users` and `user_data`.
      train_clients, train_groups, train_data = read_dir(train_data_dir)
      test_clients, test_groups, test_data = read_dir(test_data_dir)

      # Explicit checks rather than ``assert`` so they still run under ``python -O``.
      # NOTE(review): when users are partitioned into train/test groups
      # (``user=True`` in ``load_data``) the id sets may legitimately differ —
      # confirm that mode never reaches this check.
      if train_clients != test_clients:
          raise ValueError(
              f"Client ids differ between train data ({train_data_dir}) and test data ({test_data_dir})")
      if train_groups != test_groups:
          raise ValueError(
              f"Group ids differ between train data ({train_data_dir}) and test data ({test_data_dir})")
      return train_clients, train_groups, train_data, test_data
  def load_data(root,
                dataset_name,
                num_of_clients,
                split_type,
                min_size,
                class_per_client,
                data_amount,
                iid_fraction,
                user,
                train_test_split,
                quantity_weights,
                alpha):
      """Simulate and load federated datasets.

      If the train/test folders for the requested simulation setting are missing
      under ``<root>/<dataset_name>/<setting>``, this first tries to download a
      packaged archive for that setting (best effort; failures are only logged),
      then falls back to the dataset class's ``setup()`` and ``sampling()`` steps.

      Args:
          root (str): The root directory where datasets stored.
          dataset_name (str): The name of the dataset. It currently supports: femnist, shakespeare, cifar10, and cifar100.
              Among them, femnist and shakespeare are adopted from LEAF benchmark.
          num_of_clients (int): The targeted number of clients to construct.
          split_type (str): The type of statistical simulation, options: iid, niid, dir, and class.
              `iid` means independent and identically distributed data.
              `niid` means non-independent and identically distributed data for Femnist and Shakespeare.
              `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
              `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
          min_size (int): The minimal number of samples in each client.
              It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
          class_per_client (int): The number of classes in each client. Only applicable when the split_type is 'class'.
          data_amount (float): The fraction of data sampled for LEAF datasets.
              e.g., 10% means that only 10% of total dataset size are used.
          iid_fraction (float): The fraction of the number of clients used when the split_type is 'iid'.
          user (bool): A flag to indicate whether partition users of the dataset into train-test groups.
              Only applicable to LEAF datasets.
              True means partitioning users of the dataset into train-test groups.
              False means partitioning each users' samples into train-test groups.
          train_test_split (float): The fraction of data for training; the rest are for testing.
              e.g., 0.9 means 90% of data are used for training and 10% are used for testing.
          quantity_weights (list[float]): The targeted distribution of quantities to simulate data quantity heterogeneity.
              The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
              The `num_of_clients` should be divisible by `len(weights)`.
              None means clients are simulated with the same data quantity.
          alpha (float): The parameter for Dirichlet distribution simulation, applicable only when split_type is `dir`.

      Returns:
          dict: A dictionary of training data, e.g., {"id1": {"x": data, "y": label}, "id2": {"x": data, "y": label}}.
          dict: A dictionary of testing data.
          function: A function to preprocess training data (``process_x``), or None if the
              dataset's process module does not define it.
          function: A function to preprocess testing data (``process_y``), or None if not defined.
          torchvision.transforms.transforms.Compose: Training data transformation, or None if not defined.
          torchvision.transforms.transforms.Compose: Testing data transformation, or None if not defined.
      """
      user_str = "user" if user else "sample"
      setting = BaseDataset.get_setting_folder(dataset_name, split_type, num_of_clients, min_size, class_per_client,
                                               data_amount, iid_fraction, user_str, train_test_split, alpha,
                                               quantity_weights)

      # Per-dataset preprocessing hooks live in data_process/<dataset_name>.py next to this file.
      dir_path = os.path.dirname(os.path.realpath(__file__))
      dataset_file = os.path.join(dir_path, "data_process", "{}.py".format(dataset_name))
      if not os.path.exists(dataset_file):
          logger.error("Please specify a valid process file path for process_x and process_y functions.")
      dataset_path = "easyfl.datasets.data_process.{}".format(dataset_name)
      dataset_lib = importlib.import_module(dataset_path)
      process_x = getattr(dataset_lib, "process_x", None)
      process_y = getattr(dataset_lib, "process_y", None)
      transform_train = getattr(dataset_lib, "transform_train", None)
      transform_test = getattr(dataset_lib, "transform_test", None)

      data_dir = os.path.join(root, dataset_name)
      # ``data_dir`` is always a non-empty string, so existence must be checked on
      # the filesystem rather than by the string's truthiness.
      os.makedirs(data_dir, exist_ok=True)

      train_data_dir = os.path.join(data_dir, setting, "train")
      test_data_dir = os.path.join(data_dir, setting, "test")

      if not os.path.exists(train_data_dir) or not os.path.exists(test_data_dir):
          dataset_class_path = "easyfl.datasets.{}.{}".format(dataset_name, dataset_name)
          dataset_class_lib = importlib.import_module(dataset_class_path)
          class_name = dataset_name.capitalize()
          dataset = getattr(dataset_class_lib, class_name)(root=data_dir,
                                                           fraction=data_amount,
                                                           split_type=split_type,
                                                           user=user,
                                                           iid_user_fraction=iid_fraction,
                                                           train_test_split=train_test_split,
                                                           minsample=min_size,
                                                           num_of_client=num_of_clients,
                                                           class_per_client=class_per_client,
                                                           setting_folder=setting,
                                                           alpha=alpha,
                                                           weights=quantity_weights)
          # Best effort: a pre-packaged archive for this setting avoids local simulation.
          try:
              filename = f"{setting}.zip"
              dataset.download_packaged_dataset_and_extract(filename)
              logger.info(f"Downloaded packaged dataset {dataset_name}: {filename}")
          except Exception as e:
              logger.info(f"Failed to download packaged dataset: {e.args}")

          # CIFAR10 generate data in setup() stage, LEAF related datasets generate data in sampling()
          if not os.path.exists(train_data_dir):
              dataset.setup()
          if not os.path.exists(train_data_dir):
              dataset.sampling()

      # Client and group listings are not needed by callers of this function.
      _, _, train_data, test_data = read_data(dataset_name, train_data_dir, test_data_dir)
      return train_data, test_data, process_x, process_y, transform_train, transform_test
  def construct_datasets(root,
                         dataset_name,
                         num_of_clients,
                         split_type,
                         min_size,
                         class_per_client,
                         data_amount,
                         iid_fraction,
                         user,
                         train_test_split,
                         quantity_weights,
                         alpha):
      """Construct and load provided federated learning datasets.

      Loads (simulating if necessary) the requested dataset and wraps the train and
      test splits in ``FederatedTensorDataset`` objects that share the dataset's
      preprocessing hooks.

      Args:
          root (str): The root directory where datasets stored.
          dataset_name (str): The name of the dataset. It currently supports: femnist, shakespeare, cifar10, and cifar100.
              Among them, femnist and shakespeare are adopted from LEAF benchmark.
          num_of_clients (int): The targeted number of clients to construct.
          split_type (str): The type of statistical simulation, options: iid, niid, dir, and class.
              `iid` means independent and identically distributed data.
              `niid` means non-independent and identically distributed data for Femnist and Shakespeare.
              `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
              `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
          min_size (int): The minimal number of samples in each client.
              It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
          class_per_client (int): The number of classes in each client. Only applicable when the split_type is 'class'.
          data_amount (float): The fraction of data sampled for LEAF datasets.
              e.g., 10% means that only 10% of total dataset size are used.
          iid_fraction (float): The fraction of the number of clients used when the split_type is 'iid'.
          user (bool): A flag to indicate whether partition users of the dataset into train-test groups.
              Only applicable to LEAF datasets.
              True means partitioning users of the dataset into train-test groups.
              False means partitioning each users' samples into train-test groups.
          train_test_split (float): The fraction of data for training; the rest are for testing.
              e.g., 0.9 means 90% of data are used for training and 10% are used for testing.
          quantity_weights (list[float]): The targeted distribution of quantities to simulate data quantity heterogeneity.
              The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
              The `num_of_clients` should be divisible by `len(weights)`.
              None means clients are simulated with the same data quantity.
          alpha (float): The parameter for Dirichlet distribution simulation, applicable only when split_type is `dir`.

      Returns:
          :obj:`FederatedTensorDataset`: Training dataset.
          :obj:`FederatedTensorDataset`: Testing dataset.
      """
      (train_split, test_split,
       process_x, process_y,
       transform_train, transform_test) = load_data(root, dataset_name, num_of_clients, split_type, min_size,
                                                    class_per_client, data_amount, iid_fraction, user,
                                                    train_test_split, quantity_weights, alpha)

      # Training data is always partitioned per client. For CIFAR, the test split is
      # a single centralized {"x", "y"} dict (see read_data's docstring), so it is
      # flagged as not simulated; LEAF-style test data stays per client.
      is_cifar = dataset_name in (CIFAR10, CIFAR100)

      # Keyword arguments common to both wrappers.
      common_kwargs = dict(do_simulate=False, process_x=process_x, process_y=process_y)

      train_dataset = FederatedTensorDataset(train_split,
                                             simulated=True,
                                             transform=transform_train,
                                             **common_kwargs)
      test_dataset = FederatedTensorDataset(test_split,
                                            simulated=not is_cifar,
                                            transform=transform_test,
                                            **common_kwargs)
      return train_dataset, test_dataset