femnist.py

import logging
import os

from easyfl.datasets.femnist.preprocess.data_to_json import data_to_json
from easyfl.datasets.femnist.preprocess.get_file_dirs import get_file_dir
from easyfl.datasets.femnist.preprocess.get_hashes import get_hash
from easyfl.datasets.femnist.preprocess.group_by_writer import group_by_writer
from easyfl.datasets.femnist.preprocess.match_hashes import match_hash
from easyfl.datasets.utils.base_dataset import BaseDataset
from easyfl.datasets.utils.download import download_url, extract_archive, download_from_google_drive

logger = logging.getLogger(__name__)


class Femnist(BaseDataset):
    """FEMNIST dataset implementation. It fetches the FEMNIST dataset according
    to the given configurations and stores the processed datasets locally.
    See the usage sketch at the bottom of this file.

    Attributes:
        base_folder (str): The base folder path of the datasets folder.
        class_url (str): The URL of the by_class split of FEMNIST.
        write_url (str): The URL of the by_write split of FEMNIST.
    """

    def __init__(self,
                 root,
                 fraction,
                 split_type,
                 user,
                 iid_user_fraction=0.1,
                 train_test_split=0.9,
                 minsample=10,
                 num_class=62,
                 num_of_client=100,
                 class_per_client=2,
                 setting_folder=None,
                 seed=-1,
                 **kwargs):
        super(Femnist, self).__init__(root,
                                      "femnist",
                                      fraction,
                                      split_type,
                                      user,
                                      iid_user_fraction,
                                      train_test_split,
                                      minsample,
                                      num_class,
                                      num_of_client,
                                      class_per_client,
                                      setting_folder,
                                      seed)
        self.class_url = "https://s3.amazonaws.com/nist-srd/SD19/by_class.zip"
        self.write_url = "https://s3.amazonaws.com/nist-srd/SD19/by_write.zip"
        self.packaged_data_files = {
            "femnist_niid_100_10_1_0.05_0.1_sample_0.9.zip": "https://dl.dropboxusercontent.com/s/oyhegd3c0pxa0tl/femnist_niid_100_10_1_0.05_0.1_sample_0.9.zip",
            "femnist_iid_100_10_1_0.05_0.1_sample_0.9.zip": "https://dl.dropboxusercontent.com/s/jcg0xrz5qrri4tv/femnist_iid_100_10_1_0.05_0.1_sample_0.9.zip"
        }
        # Google Drive ids
        # self.packaged_data_files = {
        #     "femnist_niid_100_10_1_0.05_0.1_sample_0.9.zip": "11vAxASl-af41iHpFqW2jixs1jOUZDXMS",
        #     "femnist_iid_100_10_1_0.05_0.1_sample_0.9.zip": "1U9Sn2ACbidwhhihdJdZPfK2YddPMr33k"
        # }

    def download_packaged_dataset_and_extract(self, filename):
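        """Download one of the pre-packaged, already-partitioned datasets
        listed in ``packaged_data_files`` and extract it into the base folder.
        """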
        file_path = download_url(self.packaged_data_files[filename], self.base_folder)
        extract_archive(file_path, remove_finished=True)

    def download_raw_file_and_extract(self):
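        """Download the raw NIST SD19 by_class and by_write archives and
        extract them into ``<base_folder>/raw_data``, skipping the download
        when the folder is already populated.
        """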
        raw_data_folder = os.path.join(self.base_folder, "raw_data")
        if not os.path.exists(raw_data_folder):
            os.makedirs(raw_data_folder)
        elif os.listdir(raw_data_folder):
            logger.info("raw file exists")
            return
        class_path = download_url(self.class_url, raw_data_folder)
        write_path = download_url(self.write_url, raw_data_folder)
        extract_archive(class_path, remove_finished=True)
        extract_archive(write_path, remove_finished=True)
        logger.info("raw file is downloaded")

    def preprocess(self):
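        """Run the intermediate preprocessing pipeline, writing each stage's
        result to ``<base_folder>/intermediate`` as a .pkl file: collect image
        file paths, hash the images, match by_write images to class labels,
        and group images by writer. Stages whose output already exists are
        skipped.
        """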
        intermediate_folder = os.path.join(self.base_folder, "intermediate")
        if not os.path.exists(intermediate_folder):
            os.makedirs(intermediate_folder)
        if not os.path.exists(intermediate_folder + "/class_file_dirs.pkl"):
            logger.info("extracting file directories of images")
            get_file_dir(self.base_folder)
            logger.info("finished extracting file directories of images")
        if not os.path.exists(intermediate_folder + "/class_file_hashes.pkl"):
            logger.info("calculating image hashes")
            get_hash(self.base_folder)
            logger.info("finished calculating image hashes")
        if not os.path.exists(intermediate_folder + "/write_with_class.pkl"):
            logger.info("assigning class labels to write images")
            match_hash(self.base_folder)
            logger.info("finished assigning class labels to write images")
        if not os.path.exists(intermediate_folder + "/images_by_writer.pkl"):
            logger.info("grouping images by writer")
            group_by_writer(self.base_folder)
            logger.info("finished grouping images by writer")

    def convert_data_to_json(self):
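        """Convert the preprocessed, writer-grouped data into .json files
        under ``<base_folder>/all_data``, unless that folder is already
        populated.
        """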
        all_data_folder = os.path.join(self.base_folder, "all_data")
        if not os.path.exists(all_data_folder):
            os.makedirs(all_data_folder)
        if not os.listdir(all_data_folder):
            logger.info("converting data to .json format")
            data_to_json(self.base_folder)
            logger.info("finished converting data to .json format")
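

# A minimal usage sketch, not part of the original module: it assumes the
# methods above are meant to run in the order raw download -> intermediate
# preprocessing -> JSON conversion. The constructor arguments (root path,
# fraction=0.05, split_type="iid", user=False) are illustrative assumptions
# loosely mirroring the packaged file name
# "femnist_iid_100_10_1_0.05_0.1_sample_0.9.zip", not values prescribed by
# this file.
if __name__ == "__main__":
    femnist = Femnist(root="./data",          # assumed local dataset root
                      fraction=0.05,          # fraction of raw data to use (assumed)
                      split_type="iid",       # "iid" or "niid", per the packaged file names
                      user=False)             # sample-based rather than user-based split (assumed)
    femnist.download_raw_file_and_extract()   # fetch and unzip the by_class/by_write archives
    femnist.preprocess()                      # build the intermediate .pkl files
    femnist.convert_data_to_json()            # emit the .json data under all_data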