# split_data.py
  1. """
  2. These codes are adopted from LEAF with some modifications.
  3. Splits data into train and test sets.
  4. """
  5. import json
  6. import logging
  7. import os
  8. import random
  9. import sys
  10. import time
  11. from collections import OrderedDict
  12. from easyfl.datasets.utils.constants import SEED_FILES
  13. logger = logging.getLogger(__name__)
  14. def create_jsons_for(dir, setting_folder, user_files, which_set, max_users, include_hierarchy, subdir, arg_label):
  15. """Used in split-by-user case"""
  16. user_count = 0
  17. json_index = 0
  18. users = []
  19. if include_hierarchy:
  20. hierarchies = []
  21. else:
  22. hierarchies = None
  23. num_samples = []
  24. user_data = {}
  25. for (i, t) in enumerate(user_files):
  26. if include_hierarchy:
  27. (u, h, ns, f) = t
  28. else:
  29. (u, ns, f) = t
  30. file_dir = os.path.join(subdir, f)
  31. with open(file_dir, 'r') as inf:
  32. data = json.load(inf)
  33. users.append(u)
  34. if include_hierarchy:
  35. hierarchies.append(h)
  36. num_samples.append(ns)
  37. user_data[u] = data['user_data'][u]
  38. user_count += 1
  39. if (user_count == max_users) or (i == len(user_files) - 1):
  40. all_data = {}
  41. all_data['users'] = users
  42. if include_hierarchy:
  43. all_data['hierarchies'] = hierarchies
  44. all_data['num_samples'] = num_samples
  45. all_data['user_data'] = user_data
  46. data_i = f.find('data')
  47. num_i = data_i + 5
  48. num_to_end = f[num_i:]
  49. param_i = num_to_end.find('_')
  50. param_to_end = '.json'
  51. if param_i != -1:
  52. param_to_end = num_to_end[param_i:]
  53. nf = "{}_{}{}".format(f[:(num_i - 1)], json_index, param_to_end)
  54. file_name = '{}_{}_{}.json'.format((nf[:-5]), which_set, arg_label)
  55. ouf_dir = os.path.join(dir, setting_folder, which_set, file_name)
  56. logger.info('writing {}'.format(file_name))
  57. with open(ouf_dir, 'w') as outfile:
  58. json.dump(all_data, outfile)
  59. user_count = 0
  60. json_index += 1
  61. users = []
  62. if include_hierarchy:
  63. hierarchies = []
  64. num_samples = []
  65. user_data = {}
  66. def split_train_test(setting_folder, metafile, name, user, frac, seed):
  67. logger.info("------------------------------")
  68. logger.info("generating training and test sets")
  69. parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
  70. dir = os.path.join(parent_path, name, 'data')
  71. subdir = os.path.join(dir, setting_folder, 'rem_user_data')
  72. files = []
  73. if os.path.exists(subdir):
  74. files = os.listdir(subdir)
  75. if len(files) == 0:
  76. subdir = os.path.join(dir, setting_folder, 'sampled_data')
  77. if os.path.exists(subdir):
  78. files = os.listdir(subdir)
  79. if len(files) == 0:
  80. subdir = os.path.join(dir, 'all_data')
  81. files = os.listdir(subdir)
  82. files = [f for f in files if f.endswith('.json')]
  83. rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
  84. rng = random.Random(rng_seed)
  85. if metafile is not None:
  86. seed_fname = os.path.join(metafile, SEED_FILES['split'])
  87. with open(seed_fname, 'w+') as f:
  88. f.write("# split_seed used by sampling script - supply as "
  89. "--spltseed to preprocess.sh or --seed to utils/split_data.py\n")
  90. f.write(str(rng_seed))
  91. logger.info("- random seed written out to {file}".format(file=seed_fname))
  92. else:
  93. logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed))
  94. arg_label = str(frac)
  95. arg_label = arg_label[2:]
  96. # check if data contains information on hierarchies
  97. file_dir = os.path.join(subdir, files[0])
  98. with open(file_dir, 'r') as inf:
  99. data = json.load(inf)
  100. include_hierarchy = 'hierarchies' in data
  101. if (user):
  102. logger.info("splitting data by user")
  103. # 1 pass through all the json files to instantiate arr
  104. # containing all possible (user, .json file name) tuples
  105. user_files = []
  106. for f in files:
  107. file_dir = os.path.join(subdir, f)
  108. with open(file_dir, 'r') as inf:
  109. # Load data into an OrderedDict, to prevent ordering changes
  110. # and enable reproducibility
  111. data = json.load(inf, object_pairs_hook=OrderedDict)
  112. if include_hierarchy:
  113. user_files.extend([(u, h, ns, f) for (u, h, ns) in
  114. zip(data['users'], data['hierarchies'], data['num_samples'])])
  115. else:
  116. user_files.extend([(u, ns, f) for (u, ns) in
  117. zip(data['users'], data['num_samples'])])
  118. # randomly sample from user_files to pick training set users
  119. num_users = len(user_files)
  120. num_train_users = int(frac * num_users)
  121. indices = [i for i in range(num_users)]
  122. train_indices = rng.sample(indices, num_train_users)
  123. train_blist = [False for i in range(num_users)]
  124. for i in train_indices:
  125. train_blist[i] = True
  126. train_user_files = []
  127. test_user_files = []
  128. for i in range(num_users):
  129. if (train_blist[i]):
  130. train_user_files.append(user_files[i])
  131. else:
  132. test_user_files.append(user_files[i])
  133. max_users = sys.maxsize
  134. if name == 'femnist':
  135. max_users = 50 # max number of users per json file
  136. create_jsons_for(dir, setting_folder, train_user_files, 'train', max_users, include_hierarchy, subdir,
  137. arg_label)
  138. create_jsons_for(dir, setting_folder, test_user_files, 'test', max_users, include_hierarchy, subdir, arg_label)
  139. else:
  140. logger.info("splitting data by sample")
  141. for f in files:
  142. file_dir = os.path.join(subdir, f)
  143. with open(file_dir, 'r') as inf:
  144. # Load data into an OrderedDict, to prevent ordering changes
  145. # and enable reproducibility
  146. data = json.load(inf, object_pairs_hook=OrderedDict)
  147. num_samples_train = []
  148. user_data_train = {}
  149. num_samples_test = []
  150. user_data_test = {}
  151. user_indices = [] # indices of users in data['users'] that are not deleted
  152. for i, u in enumerate(data['users']):
  153. user_data_train[u] = {'x': [], 'y': []}
  154. user_data_test[u] = {'x': [], 'y': []}
  155. curr_num_samples = len(data['user_data'][u]['y'])
  156. if curr_num_samples >= 2:
  157. user_indices.append(i)
  158. # ensures number of train and test samples both >= 1
  159. num_train_samples = max(1, int(frac * curr_num_samples))
  160. if curr_num_samples == 2:
  161. num_train_samples = 1
  162. num_test_samples = curr_num_samples - num_train_samples
  163. num_samples_train.append(num_train_samples)
  164. num_samples_test.append(num_test_samples)
  165. indices = [j for j in range(curr_num_samples)]
  166. train_indices = rng.sample(indices, num_train_samples)
  167. train_blist = [False for _ in range(curr_num_samples)]
  168. for j in train_indices:
  169. train_blist[j] = True
  170. for j in range(curr_num_samples):
  171. if (train_blist[j]):
  172. user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
  173. user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
  174. else:
  175. user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
  176. user_data_test[u]['y'].append(data['user_data'][u]['y'][j])
  177. users = [data['users'][i] for i in user_indices]
  178. all_data_train = {}
  179. all_data_train['users'] = users
  180. all_data_train['num_samples'] = num_samples_train
  181. all_data_train['user_data'] = user_data_train
  182. all_data_test = {}
  183. all_data_test['users'] = users
  184. all_data_test['num_samples'] = num_samples_test
  185. all_data_test['user_data'] = user_data_test
  186. if include_hierarchy:
  187. all_data_train['hierarchies'] = data['hierarchies']
  188. all_data_test['hierarchies'] = data['hierarchies']
  189. file_name_train = '{}_train_{}.json'.format((f[:-5]), arg_label)
  190. file_name_test = '{}_test_{}.json'.format((f[:-5]), arg_label)
  191. ouf_dir_train = os.path.join(dir, setting_folder, 'train', file_name_train)
  192. ouf_dir_test = os.path.join(dir, setting_folder, 'test', file_name_test)
  193. logger.info("writing {}".format(file_name_train))
  194. with open(ouf_dir_train, 'w') as outfile:
  195. json.dump(all_data_train, outfile)
  196. logger.info("writing {}".format(file_name_test))
  197. with open(ouf_dir_test, 'w') as outfile:
  198. json.dump(all_data_test, outfile)