# shakespeare.py
  1. import logging
  2. import os
  3. from easyfl.datasets.shakespeare.utils.gen_all_data import generated_all_data
  4. from easyfl.datasets.shakespeare.utils.preprocess_shakespeare import shakespeare_preprocess
  5. from easyfl.datasets.utils.base_dataset import BaseDataset
  6. from easyfl.datasets.utils.download import download_url, extract_archive, download_from_google_drive
# Module-level logger named after this module, per the standard logging convention.
logger = logging.getLogger(__name__)
  8. class Shakespeare(BaseDataset):
  9. """Shakespeare dataset implementation. It gets Shakespeare dataset according to configurations.
  10. Attributes:
  11. base_folder (str): The base folder path of the datasets folder.
  12. raw_data_url (str): The url to get the `by_class` split shakespeare.
  13. write_url (str): The url to get the `by_write` split shakespeare.
  14. """
  15. def __init__(self,
  16. root,
  17. fraction,
  18. split_type,
  19. user,
  20. iid_user_fraction=0.1,
  21. train_test_split=0.9,
  22. minsample=10,
  23. num_class=80,
  24. num_of_client=100,
  25. class_per_client=2,
  26. setting_folder=None,
  27. seed=-1,
  28. **kwargs):
  29. super(Shakespeare, self).__init__(root,
  30. "shakespeare",
  31. fraction,
  32. split_type,
  33. user,
  34. iid_user_fraction,
  35. train_test_split,
  36. minsample,
  37. num_class,
  38. num_of_client,
  39. class_per_client,
  40. setting_folder,
  41. seed)
  42. self.raw_data_url = "http://www.gutenberg.org/files/100/old/1994-01-100.zip"
  43. self.packaged_data_files = {
  44. "shakespeare_niid_100_10_1_0.05_0.1_sample_0.9.zip": "https://dl.dropboxusercontent.com/s/5qr9ozziy3yfzss/shakespeare_niid_100_10_1_0.05_0.1_sample_0.9.zip",
  45. "shakespeare_iid_100_10_1_0.05_0.1_sample_0.9.zip": "https://dl.dropboxusercontent.com/s/4p7osgjd2pecsi3/shakespeare_iid_100_10_1_0.05_0.1_sample_0.9.zip"
  46. }
  47. # Google drive ids.
  48. # self.packaged_data_files = {
  49. # "shakespeare_niid_100_10_1_0.05_0.1_sample_0.9.zip": "1zvmNiUNu7r0h4t0jBhOJ204qyc61NvfJ",
  50. # "shakespeare_iid_100_10_1_0.05_0.1_sample_0.9.zip": "1Lb8n1zDtrj2DX_QkjNnL6DH5IrnYFdsR"
  51. # }
  52. def download_packaged_dataset_and_extract(self, filename):
  53. file_path = download_url(self.packaged_data_files[filename], self.base_folder)
  54. extract_archive(file_path, remove_finished=True)
  55. def download_raw_file_and_extract(self):
  56. raw_data_folder = os.path.join(self.base_folder, "raw_data")
  57. if not os.path.exists(raw_data_folder):
  58. os.makedirs(raw_data_folder)
  59. elif os.listdir(raw_data_folder):
  60. logger.info("raw file exists")
  61. return
  62. raw_data_path = download_url(self.raw_data_url, raw_data_folder)
  63. extract_archive(raw_data_path, remove_finished=True)
  64. os.rename(os.path.join(raw_data_folder, "100.txt"), os.path.join(raw_data_folder, "raw_data.txt"))
  65. logger.info("raw file is downloaded")
  66. def preprocess(self):
  67. filename = os.path.join(self.base_folder, "raw_data", "raw_data.txt")
  68. raw_data_folder = os.path.join(self.base_folder, "raw_data")
  69. if not os.path.exists(raw_data_folder):
  70. os.makedirs(raw_data_folder)
  71. shakespeare_preprocess(filename, raw_data_folder)
  72. def convert_data_to_json(self):
  73. all_data_folder = os.path.join(self.base_folder, "all_data")
  74. if not os.path.exists(all_data_folder):
  75. os.makedirs(all_data_folder)
  76. if not os.listdir(all_data_folder):
  77. logger.info("converting data to .json format")
  78. generated_all_data(self.base_folder)
  79. logger.info("finished converting data to .json format")