dataset.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. import logging
  2. import os
  3. from abc import ABC, abstractmethod
  4. import numpy as np
  5. import torch
  6. from torch.utils.data import TensorDataset, DataLoader
  7. from torchvision.datasets.folder import default_loader, make_dataset
  8. from easyfl.datasets.dataset_util import TransformDataset, ImageDataset
  9. from easyfl.datasets.simulation import data_simulation, SIMULATE_IID
  10. logger = logging.getLogger(__name__)
  11. TEST_IN_SERVER = "test_in_server"
  12. TEST_IN_CLIENT = "test_in_client"
  13. IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
  14. DEFAULT_MERGED_ID = "Merged"
  15. def default_process_x(raw_x_batch):
  16. return torch.tensor(raw_x_batch)
  17. def default_process_y(raw_y_batch):
  18. return torch.tensor(raw_y_batch)
  19. class FederatedDataset(ABC):
  20. """The abstract class of federated dataset for EasyFL."""
  21. def __init__(self):
  22. pass
  23. @abstractmethod
  24. def loader(self, batch_size, shuffle=True):
  25. """Get data loader.
  26. Args:
  27. batch_size (int): The batch size of the data loader.
  28. shuffle (bool): Whether shuffle the data in the loader.
  29. """
  30. raise NotImplementedError("Data loader not implemented")
  31. @abstractmethod
  32. def size(self, cid):
  33. """Get dataset size.
  34. Args:
  35. cid (str): client id.
  36. """
  37. raise NotImplementedError("Size not implemented")
  38. @property
  39. def users(self):
  40. """Get client ids of the federated dataset."""
  41. raise NotImplementedError("Users not implemented")
  42. class FederatedTensorDataset(FederatedDataset):
  43. """Federated tensor dataset, data of clients are in format of tensor or list.
  44. Args:
  45. data (dict): A dictionary of data, e.g., {"id1": {"x": [[], [], ...], "y": [...]]}}.
  46. If simulation is not done previously, it is in format of {'x':[[],[], ...], 'y': [...]}.
  47. transform (torchvision.transforms.transforms.Compose, optional): Transformation for data.
  48. target_transform (torchvision.transforms.transforms.Compose, optional): Transformation for data labels.
  49. process_x (function, optional): A function to preprocess training data.
  50. process_y (function, optional): A function to preprocess testing data.
  51. simulated (bool, optional): Whether the dataset is simulated to federated learning settings.
  52. do_simulate (bool, optional): Whether conduct simulation. It is only effective if it is not simulated.
  53. num_of_clients (int, optional): number of clients for simulation. Only need if doing simulation.
  54. simulation_method(optional): split method. Only need if doing simulation.
  55. weights (list[float], optional): The targeted distribution of quantities to simulate quantity heterogeneity.
  56. The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
  57. The `num_of_clients` should be divisible by `len(weights)`.
  58. None means clients are simulated with the same data quantity.
  59. alpha (float, optional): The parameter for Dirichlet distribution simulation, only for dir simulation.
  60. min_size (int, optional): The minimal number of samples in each client, only for dir simulation.
  61. class_per_client (int, optional): The number of classes in each client, only for non-iid by class simulation.
  62. """
  63. def __init__(self,
  64. data,
  65. transform=None,
  66. target_transform=None,
  67. process_x=default_process_x,
  68. process_y=default_process_x,
  69. simulated=False,
  70. do_simulate=True,
  71. num_of_clients=10,
  72. simulation_method=SIMULATE_IID,
  73. weights=None,
  74. alpha=0.5,
  75. min_size=10,
  76. class_per_client=1):
  77. super(FederatedTensorDataset, self).__init__()
  78. self.simulated = simulated
  79. self.data = data
  80. self._validate_data(self.data)
  81. self.process_x = process_x
  82. self.process_y = process_y
  83. self.transform = transform
  84. self.target_transform = target_transform
  85. if simulated:
  86. self._users = sorted(list(self.data.keys()))
  87. elif do_simulate:
  88. # For simulation method provided, we support testing in server for now
  89. # TODO: support simulation for test data => test in clients
  90. self.simulation(num_of_clients, simulation_method, weights, alpha, min_size, class_per_client)
  91. def simulation(self, num_of_clients, niid=SIMULATE_IID, weights=None, alpha=0.5, min_size=10, class_per_client=1):
  92. if self.simulated:
  93. logger.warning("The dataset is already simulated, the simulation would not proceed.")
  94. return
  95. self._users, self.data = data_simulation(
  96. self.data['x'],
  97. self.data['y'],
  98. num_of_clients,
  99. niid,
  100. weights,
  101. alpha,
  102. min_size,
  103. class_per_client)
  104. self.simulated = True
  105. def loader(self, batch_size, client_id=None, shuffle=True, seed=0, transform=None, drop_last=False):
  106. """Get dataset loader.
  107. Args:
  108. batch_size (int): The batch size.
  109. client_id (str, optional): The id of client.
  110. shuffle (bool, optional): Whether to shuffle before batching.
  111. seed (int, optional): The shuffle seed.
  112. transform (torchvision.transforms.transforms.Compose, optional): Data transformation.
  113. drop_last (bool, optional): Whether to drop the last batch if its size is smaller than batch size.
  114. Returns:
  115. torch.utils.data.DataLoader: The data loader to load data.
  116. """
  117. # Simulation need to be done before creating a data loader
  118. if client_id is None:
  119. data = self.data
  120. else:
  121. data = self.data[client_id]
  122. data_x = data['x']
  123. data_y = data['y']
  124. data_x = np.array(data_x)
  125. data_y = np.array(data_y)
  126. data_x = self._input_process(data_x)
  127. data_y = self._label_process(data_y)
  128. if shuffle:
  129. np.random.seed(seed)
  130. rng_state = np.random.get_state()
  131. np.random.shuffle(data_x)
  132. np.random.set_state(rng_state)
  133. np.random.shuffle(data_y)
  134. transform = self.transform if transform is None else transform
  135. if transform is not None:
  136. dataset = TransformDataset(data_x,
  137. data_y,
  138. transform_x=transform,
  139. transform_y=self.target_transform)
  140. else:
  141. dataset = TensorDataset(data_x, data_y)
  142. loader = DataLoader(dataset=dataset,
  143. batch_size=batch_size,
  144. shuffle=shuffle,
  145. drop_last=drop_last)
  146. return loader
  147. @property
  148. def users(self):
  149. return self._users
  150. @users.setter
  151. def users(self, value):
  152. self._users = value
  153. def size(self, cid=None):
  154. if cid is not None:
  155. return len(self.data[cid]['y'])
  156. else:
  157. return len(self.data['y'])
  158. def total_size(self):
  159. if 'y' in self.data:
  160. return len(self.data['y'])
  161. else:
  162. return sum([len(self.data[i]['y']) for i in self.data])
  163. def _input_process(self, sample):
  164. if self.process_x is not None:
  165. sample = self.process_x(sample)
  166. return sample
  167. def _label_process(self, label):
  168. if self.process_y is not None:
  169. label = self.process_y(label)
  170. return label
  171. def _validate_data(self, data):
  172. if self.simulated:
  173. for i in data:
  174. assert len(data[i]['x']) == len(data[i]['y'])
  175. else:
  176. assert len(data['x']) == len(data['y'])
  177. class FederatedImageDataset(FederatedDataset):
  178. """
  179. Federated image dataset, data of clients are in format of image folder.
  180. Args:
  181. root (str|list[str]): The root directory or directories of image data folder.
  182. If the dataset is simulated to multiple clients, the root is a list of directories.
  183. Otherwise, it is the directory of an image data folder.
  184. simulated (bool): Whether the dataset is simulated to federated learning settings.
  185. do_simulate (bool, optional): Whether conduct simulation. It is only effective if it is not simulated.
  186. extensions (list[str], optional): A list of allowed image extensions.
  187. Only one of `extensions` and `is_valid_file` can be specified.
  188. is_valid_file (function, optional): A function that takes path of an Image file and check if it is valid.
  189. Only one of `extensions` and `is_valid_file` can be specified.
  190. transform (torchvision.transforms.transforms.Compose, optional): Transformation for data.
  191. target_transform (torchvision.transforms.transforms.Compose, optional): Transformation for data labels.
  192. num_of_clients (int, optional): number of clients for simulation. Only need if doing simulation.
  193. simulation_method(optional): split method. Only need if doing simulation.
  194. weights (list[float], optional): The targeted distribution of quantities to simulate quantity heterogeneity.
  195. The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
  196. The `num_of_clients` should be divisible by `len(weights)`.
  197. None means clients are simulated with the same data quantity.
  198. alpha (float, optional): The parameter for Dirichlet distribution simulation, only for dir simulation.
  199. min_size (int, optional): The minimal number of samples in each client, only for dir simulation.
  200. class_per_client (int, optional): The number of classes in each client, only for non-iid by class simulation.
  201. client_ids (list[str], optional): A list of client ids.
  202. Each client id matches with an element in roots.
  203. The client ids are ["f0000001", "f00000002", ...] if not specified.
  204. """
  205. def __init__(self,
  206. root,
  207. simulated,
  208. do_simulate=True,
  209. extensions=IMG_EXTENSIONS,
  210. is_valid_file=None,
  211. transform=None,
  212. target_transform=None,
  213. client_ids="default",
  214. num_of_clients=10,
  215. simulation_method=SIMULATE_IID,
  216. weights=None,
  217. alpha=0.5,
  218. min_size=10,
  219. class_per_client=1):
  220. super(FederatedImageDataset, self).__init__()
  221. self.simulated = simulated
  222. self.transform = transform
  223. self.target_transform = target_transform
  224. if self.simulated:
  225. self.data = {}
  226. self.classes = {}
  227. self.class_to_idx = {}
  228. self.roots = root
  229. self.num_of_clients = len(self.roots)
  230. if client_ids == "default":
  231. self.users = ["f%07.0f" % (i) for i in range(len(self.roots))]
  232. else:
  233. self.users = client_ids
  234. for i in range(self.num_of_clients):
  235. current_client_id = self.users[i]
  236. classes, class_to_idx = self._find_classes(self.roots[i])
  237. samples = make_dataset(self.roots[i], class_to_idx, extensions, is_valid_file)
  238. if len(samples) == 0:
  239. msg = "Found 0 files in subfolders of: {}\n".format(self.root)
  240. if extensions is not None:
  241. msg += "Supported extensions are: {}".format(",".join(extensions))
  242. raise RuntimeError(msg)
  243. self.classes[current_client_id] = classes
  244. self.class_to_idx[current_client_id] = class_to_idx
  245. temp_client = {'x': [i[0] for i in samples], 'y': [i[1] for i in samples]}
  246. self.data[current_client_id] = temp_client
  247. elif do_simulate:
  248. self.root = root
  249. classes, class_to_idx = self._find_classes(self.root)
  250. samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
  251. if len(samples) == 0:
  252. msg = "Found 0 files in subfolders of: {}\n".format(self.root)
  253. if extensions is not None:
  254. msg += "Supported extensions are: {}".format(",".join(extensions))
  255. raise RuntimeError(msg)
  256. self.extensions = extensions
  257. self.classes = classes
  258. self.class_to_idx = class_to_idx
  259. self.samples = samples
  260. self.inputs = [i[0] for i in self.samples]
  261. self.labels = [i[1] for i in self.samples]
  262. self.simulation(num_of_clients, simulation_method, weights, alpha, min_size, class_per_client)
  263. def simulation(self, num_of_clients, niid="iid", weights=[1], alpha=0.5, min_size=10, class_per_client=1):
  264. if self.simulated:
  265. logger.warning("The dataset is already simulated, the simulation would not proceed.")
  266. return
  267. self.users, self.data = data_simulation(self.inputs,
  268. self.labels,
  269. num_of_clients,
  270. niid,
  271. weights,
  272. alpha,
  273. min_size,
  274. class_per_client)
  275. self.simulated = True
  276. def loader(self, batch_size, client_id=None, shuffle=True, seed=0, num_workers=2, transform=None):
  277. """Get dataset loader.
  278. Args:
  279. batch_size (int): The batch size.
  280. client_id (str, optional): The id of client.
  281. shuffle (bool, optional): Whether to shuffle before batching.
  282. seed (int, optional): The shuffle seed.
  283. transform (torchvision.transforms.transforms.Compose, optional): Data transformation.
  284. num_workers (int, optional): The number of workers for dataset loader.
  285. Returns:
  286. torch.utils.data.DataLoader: The data loader to load data.
  287. """
  288. assert self.simulated is True
  289. if client_id is None:
  290. data = self.data
  291. else:
  292. data = self.data[client_id]
  293. data_x = data['x'][:]
  294. data_y = data['y'][:]
  295. # randomly shuffle data
  296. if shuffle:
  297. np.random.seed(seed)
  298. rng_state = np.random.get_state()
  299. np.random.shuffle(data_x)
  300. np.random.set_state(rng_state)
  301. np.random.shuffle(data_y)
  302. transform = self.transform if transform is None else transform
  303. dataset = ImageDataset(data_x, data_y, transform, self.target_transform)
  304. loader = torch.utils.data.DataLoader(dataset,
  305. batch_size=batch_size,
  306. shuffle=shuffle,
  307. num_workers=num_workers,
  308. pin_memory=False)
  309. return loader
  310. @property
  311. def users(self):
  312. return self._users
  313. @users.setter
  314. def users(self, value):
  315. self._users = value
  316. def size(self, cid=None):
  317. if cid is not None:
  318. return len(self.data[cid]['y'])
  319. else:
  320. return len(self.data['y'])
  321. def _find_classes(self, dir):
  322. """Get the classes of the dataset.
  323. Args:
  324. dir (str): Root directory path.
  325. Returns:
  326. tuple: (classes, class_to_idx) where classes are relative to directory and class_to_idx is a dictionary.
  327. Note:
  328. Need to ensure that no class is a subdirectory of another.
  329. """
  330. classes = [d.name for d in os.scandir(dir) if d.is_dir()]
  331. classes.sort()
  332. class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
  333. return classes, class_to_idx
  334. class FederatedTorchDataset(FederatedDataset):
  335. """Wrapper over PyTorch dataset.
  336. Args:
  337. data (dict): A dictionary of client datasets, format {"client_id": loader1, "client_id2": loader2}.
  338. """
  339. def __init__(self, data, users):
  340. super(FederatedTorchDataset, self).__init__()
  341. self.data = data
  342. self._users = users
  343. def loader(self, batch_size, client_id=None, shuffle=True, seed=0, num_workers=2, transform=None):
  344. if client_id is None:
  345. data = self.data
  346. else:
  347. data = self.data[client_id]
  348. loader = torch.utils.data.DataLoader(
  349. data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
  350. return loader
  351. @property
  352. def users(self):
  353. return self._users
  354. @users.setter
  355. def users(self, value):
  356. self._users = value
  357. def size(self, cid=None):
  358. if cid is not None:
  359. return len(self.data[cid])
  360. else:
  361. return len(self.data)