cifar10.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import logging
  2. import os
  3. import torchvision
  4. from easyfl.datasets.simulation import data_simulation
  5. from easyfl.datasets.utils.base_dataset import BaseDataset, CIFAR10
  6. from easyfl.datasets.utils.util import save_dict
  7. logger = logging.getLogger(__name__)
  class Cifar10(BaseDataset):
      """CIFAR-10 dataset downloaded with torchvision and split across simulated clients.

      Most constructor arguments are forwarded unchanged to ``BaseDataset``;
      those this class also reads itself are noted below.

      Args:
          root: Forwarded to ``BaseDataset``.
          fraction: Forwarded to ``BaseDataset``.
          split_type: Forwarded to ``BaseDataset`` and passed to
              ``data_simulation`` as the partitioning mode.
          user: Forwarded to ``BaseDataset``.
          iid_user_fraction: Forwarded to ``BaseDataset``.
          train_test_split: Forwarded to ``BaseDataset``.
          minsample: Forwarded to ``BaseDataset``; also stored as
              ``min_size`` and passed to ``data_simulation``.
          num_class: Forwarded to ``BaseDataset``.
              NOTE(review): CIFAR-10 has 10 classes, so the default of 80
              looks suspicious. It is kept for backward compatibility;
              confirm how ``BaseDataset`` uses it.
          num_of_client: Forwarded to ``BaseDataset``; also passed to
              ``data_simulation`` as the client count.
          class_per_client: Forwarded to ``BaseDataset``; also passed to
              ``data_simulation``.
          setting_folder: Forwarded to ``BaseDataset``.
          seed: Forwarded to ``BaseDataset``.
          weights: Passed to ``data_simulation``. When not ``None``,
              ``preprocess`` re-runs the simulation even if a saved split
              already exists.
          alpha: Passed to ``data_simulation`` (presumably a concentration
              parameter for non-IID splits -- confirm against
              ``data_simulation``).
      """

      def __init__(self,
                   root,
                   fraction,
                   split_type,
                   user,
                   iid_user_fraction=0.1,
                   train_test_split=0.9,
                   minsample=10,
                   num_class=80,
                   num_of_client=100,
                   class_per_client=2,
                   setting_folder=None,
                   seed=-1,
                   weights=None,
                   alpha=0.5):
          super().__init__(root,
                           CIFAR10,
                           fraction,
                           split_type,
                           user,
                           iid_user_fraction,
                           train_test_split,
                           minsample,
                           num_class,
                           num_of_client,
                           class_per_client,
                           setting_folder,
                           seed)
          # Filled by download_raw_file_and_extract as {'x': ..., 'y': ...}.
          self.train_data, self.test_data = {}, {}
          # Values consumed by preprocess() when calling data_simulation.
          self.split_type = split_type
          self.num_of_client = num_of_client
          self.weights = weights
          self.alpha = alpha
          self.min_size = minsample
          self.class_per_client = class_per_client

      def download_packaged_dataset_and_extract(self, filename):
          """No-op: CIFAR-10 is fetched through torchvision instead."""
          pass

      def download_raw_file_and_extract(self):
          """Download CIFAR-10 via torchvision into ``self.base_folder``.

          Stores the train and test splits on ``self.train_data`` and
          ``self.test_data`` as dicts whose ``'x'`` holds the torchvision
          dataset's ``.data`` and whose ``'y'`` holds its ``.targets``.
          """
          train_set = torchvision.datasets.CIFAR10(root=self.base_folder, train=True, download=True)
          test_set = torchvision.datasets.CIFAR10(root=self.base_folder, train=False, download=True)
          self.train_data = {
              'x': train_set.data,
              'y': train_set.targets
          }
          self.test_data = {
              'x': test_set.data,
              'y': test_set.targets
          }

      def preprocess(self):
          """Split the training data across clients and save both splits.

          The partitioned training data from ``data_simulation`` is written
          to ``<data_folder>/train``, and the unpartitioned test data to
          ``<data_folder>/test``, both via ``save_dict``.

          The work is skipped when ``weights`` is ``None`` and the train file
          already exists. Only the train file is checked.

          Assumes ``download_raw_file_and_extract`` has already populated
          ``self.train_data``; otherwise indexing it raises ``KeyError``
          (TODO confirm the call order in ``BaseDataset``).
          """
          train_data_path = os.path.join(self.data_folder, "train")
          test_data_path = os.path.join(self.data_folder, "test")
          # exist_ok=True avoids the exists()/makedirs() race, where another
          # process creating the folder in between would raise FileExistsError.
          os.makedirs(self.data_folder, exist_ok=True)
          # Reuse a saved split only in the unweighted case. Presumably a
          # weighted request may differ from what was saved before.
          if self.weights is None and os.path.exists(train_data_path):
              return
          logger.info("Start CIFAR10 data simulation")
          # The first return value of data_simulation is not used here.
          _, train_data = data_simulation(self.train_data['x'],
                                          self.train_data['y'],
                                          self.num_of_client,
                                          self.split_type,
                                          self.weights,
                                          self.alpha,
                                          self.min_size,
                                          self.class_per_client)
          logger.info("Complete CIFAR10 data simulation")
          save_dict(train_data, train_data_path)
          save_dict(self.test_data, test_data_path)

      def convert_data_to_json(self):
          """No-op: data is saved directly via ``save_dict`` in ``preprocess``."""
          pass