# sample.py
  1. """
  2. These codes are adopted from LEAF with some modifications.
  3. Samples from all raw data;
  4. by default samples in a non-iid manner; namely, randomly selects users from
  5. raw data until their cumulative amount of data exceeds the given number of
  6. datapoints to sample (specified by --fraction argument);
  7. ordering of original data points is not preserved in sampled data
  8. """
  9. import json
  10. import logging
  11. import os
  12. import random
  13. import time
  14. from collections import OrderedDict
  15. from easyfl.datasets.simulation import non_iid_class
  16. from easyfl.datasets.utils.constants import SEED_FILES
  17. from easyfl.datasets.utils.util import iid_divide
  18. logger = logging.getLogger(__name__)
  19. def extreme(data_dir, data_folder, metafile, fraction, num_class=62, num_of_client=100, class_per_client=2, seed=-1):
  20. """
  21. Note: for extreme split, there are two ways, one is divide each class into parts and then distribute to the clients;
  22. The second way is to let clients to go through classes to get a part of the data; Current version is the latter one, we
  23. can also provide the previous one (the one we adopt in CIFA10); If (num_of_client*class_per_client)%num_class, there is no
  24. difference(assume each class is equal), otherwise, how to deal with some remain parts is a question to discuss. (currently,
  25. the method will just give the remain part to the next client coming for collection, which may make the last clients have more
  26. than class_per_client;)
  27. """
  28. logger.info("------------------------------")
  29. logger.info("sampling data")
  30. subdir = os.path.join(data_dir, 'all_data')
  31. files = os.listdir(subdir)
  32. files = [f for f in files if f.endswith('.json')]
  33. rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
  34. logger.info("Using seed {}".format(rng_seed))
  35. rng = random.Random(rng_seed)
  36. logger.info(metafile)
  37. if metafile is not None:
  38. seed_fname = os.path.join(metafile, SEED_FILES['sampling'])
  39. with open(seed_fname, 'w+') as f:
  40. f.write("# sampling_seed used by sampling script - supply as "
  41. "--smplseed to preprocess.sh or --seed to utils/sample.py\n")
  42. f.write(str(rng_seed))
  43. logger.info("- random seed written out to {file}".format(file=seed_fname))
  44. else:
  45. logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed))
  46. new_user_count = 0 # for iid case
  47. all_users = []
  48. all_user_data = {}
  49. for f in files:
  50. file_dir = os.path.join(subdir, f)
  51. with open(file_dir, 'r') as inf:
  52. data = json.load(inf, object_pairs_hook=OrderedDict)
  53. num_users = len(data['users'])
  54. tot_num_samples = sum(data['num_samples'])
  55. num_new_samples = int(fraction * tot_num_samples)
  56. raw_list = list(data['user_data'].values())
  57. raw_x = [elem['x'] for elem in raw_list]
  58. raw_y = [elem['y'] for elem in raw_list]
  59. x_list = [item for sublist in raw_x for item in sublist] # flatten raw_x
  60. y_list = [item for sublist in raw_y for item in sublist] # flatten raw_y
  61. num_new_users = num_users
  62. indices = [i for i in range(tot_num_samples)]
  63. new_indices = rng.sample(indices, num_new_samples)
  64. users = [str(i + new_user_count) for i in range(num_new_users)]
  65. all_users.extend(users)
  66. user_data = {}
  67. for user in users:
  68. user_data[user] = {'x': [], 'y': []}
  69. all_x_samples = [x_list[i] for i in new_indices]
  70. all_y_samples = [y_list[i] for i in new_indices]
  71. x_groups = iid_divide(all_x_samples, num_new_users)
  72. y_groups = iid_divide(all_y_samples, num_new_users)
  73. for i in range(num_new_users):
  74. user_data[users[i]]['x'] = x_groups[i]
  75. user_data[users[i]]['y'] = y_groups[i]
  76. all_user_data.update(user_data)
  77. num_samples = [len(user_data[u]['y']) for u in users]
  78. new_user_count += num_new_users
  79. allx = []
  80. ally = []
  81. for i in all_users:
  82. allx.extend(all_user_data[i]['x'])
  83. ally.extend(all_user_data[i]['y'])
  84. clients, all_user_data = non_iid_class(x_list, y_list, class_per_client, num_of_client)
  85. # ------------
  86. # create .json file
  87. all_num_samples = []
  88. for i in clients:
  89. all_num_samples.append(len(all_user_data[i]['y']))
  90. all_data = {}
  91. all_data['users'] = clients
  92. all_data['num_samples'] = all_num_samples
  93. all_data['user_data'] = all_user_data
  94. slabel = ''
  95. arg_frac = str(fraction)
  96. arg_frac = arg_frac[2:]
  97. arg_label = arg_frac
  98. file_name = '%s_%s_%s.json' % ("class", slabel, arg_label)
  99. ouf_dir = os.path.join(data_folder, 'sampled_data', file_name)
  100. logger.info("writing {}".format(file_name))
  101. with open(ouf_dir, 'w') as outfile:
  102. json.dump(all_data, outfile)
  103. def sample(data_dir, data_folder, metafile, fraction, iid, iid_user_fraction=0.01, seed=-1):
  104. logger.info("------------------------------")
  105. logger.info("sampling data")
  106. subdir = os.path.join(data_dir, 'all_data')
  107. files = os.listdir(subdir)
  108. files = [f for f in files if f.endswith('.json')]
  109. rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
  110. logger.info("Using seed {}".format(rng_seed))
  111. rng = random.Random(rng_seed)
  112. logger.info(metafile)
  113. if metafile is not None:
  114. seed_fname = os.path.join(metafile, SEED_FILES['sampling'])
  115. with open(seed_fname, 'w+') as f:
  116. f.write("# sampling_seed used by sampling script - supply as "
  117. "--smplseed to preprocess.sh or --seed to utils/sample.py\n")
  118. f.write(str(rng_seed))
  119. logger.info("- random seed written out to {file}".format(file=seed_fname))
  120. else:
  121. logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed))
  122. new_user_count = 0 # for iid case
  123. for f in files:
  124. file_dir = os.path.join(subdir, f)
  125. with open(file_dir, 'r') as inf:
  126. # Load data into an OrderedDict, to prevent ordering changes
  127. # and enable reproducibility
  128. data = json.load(inf, object_pairs_hook=OrderedDict)
  129. num_users = len(data['users'])
  130. tot_num_samples = sum(data['num_samples'])
  131. num_new_samples = int(fraction * tot_num_samples)
  132. hierarchies = None
  133. if iid:
  134. # iid in femnist is to put all data together, and then split them according to
  135. # iid_user_fraction * num_users numbers of clients evenly
  136. raw_list = list(data['user_data'].values())
  137. raw_x = [elem['x'] for elem in raw_list]
  138. raw_y = [elem['y'] for elem in raw_list]
  139. x_list = [item for sublist in raw_x for item in sublist] # flatten raw_x
  140. y_list = [item for sublist in raw_y for item in sublist] # flatten raw_y
  141. num_new_users = int(round(iid_user_fraction * num_users))
  142. if num_new_users == 0:
  143. num_new_users += 1
  144. indices = [i for i in range(tot_num_samples)]
  145. new_indices = rng.sample(indices, num_new_samples)
  146. users = ["f%07.0f" % (i + new_user_count) for i in range(num_new_users)]
  147. user_data = {}
  148. for user in users:
  149. user_data[user] = {'x': [], 'y': []}
  150. all_x_samples = [x_list[i] for i in new_indices]
  151. all_y_samples = [y_list[i] for i in new_indices]
  152. x_groups = iid_divide(all_x_samples, num_new_users)
  153. y_groups = iid_divide(all_y_samples, num_new_users)
  154. for i in range(num_new_users):
  155. user_data[users[i]]['x'] = x_groups[i]
  156. user_data[users[i]]['y'] = y_groups[i]
  157. num_samples = [len(user_data[u]['y']) for u in users]
  158. new_user_count += num_new_users
  159. else:
  160. # niid's fraction in femnist is to choose some clients, one by one,
  161. # until the data size meets the fration * total data size
  162. ctot_num_samples = 0
  163. users = data['users']
  164. users_and_hiers = None
  165. if 'hierarchies' in data:
  166. users_and_hiers = list(zip(users, data['hierarchies']))
  167. rng.shuffle(users_and_hiers)
  168. else:
  169. rng.shuffle(users)
  170. user_i = 0
  171. num_samples = []
  172. user_data = {}
  173. if 'hierarchies' in data:
  174. hierarchies = []
  175. while ctot_num_samples < num_new_samples:
  176. hierarchy = None
  177. if users_and_hiers is not None:
  178. user, hier = users_and_hiers[user_i]
  179. else:
  180. user = users[user_i]
  181. cdata = data['user_data'][user]
  182. cnum_samples = len(data['user_data'][user]['y'])
  183. if ctot_num_samples + cnum_samples > num_new_samples:
  184. cnum_samples = num_new_samples - ctot_num_samples
  185. indices = [i for i in range(cnum_samples)]
  186. new_indices = rng.sample(indices, cnum_samples)
  187. x = []
  188. y = []
  189. for i in new_indices:
  190. x.append(data['user_data'][user]['x'][i])
  191. y.append(data['user_data'][user]['y'][i])
  192. cdata = {'x': x, 'y': y}
  193. if 'hierarchies' in data:
  194. hierarchies.append(hier)
  195. num_samples.append(cnum_samples)
  196. user_data[user] = cdata
  197. ctot_num_samples += cnum_samples
  198. user_i += 1
  199. if 'hierarchies' in data:
  200. users = [u for u, h in users_and_hiers][:user_i]
  201. else:
  202. users = users[:user_i]
  203. # ------------
  204. # create .json file
  205. all_data = {}
  206. all_data['users'] = users
  207. if hierarchies is not None:
  208. all_data['hierarchies'] = hierarchies
  209. all_data['num_samples'] = num_samples
  210. all_data['user_data'] = user_data
  211. slabel = 'niid'
  212. if iid:
  213. slabel = 'iid'
  214. arg_frac = str(fraction)
  215. arg_frac = arg_frac[2:]
  216. arg_nu = str(iid_user_fraction)
  217. arg_nu = arg_nu[2:]
  218. arg_label = arg_frac
  219. if iid:
  220. arg_label = '%s_%s' % (arg_nu, arg_label)
  221. file_name = '%s_%s_%s.json' % ((f[:-5]), slabel, arg_label)
  222. ouf_dir = os.path.join(data_folder, 'sampled_data', file_name)
  223. logger.info('writing %s' % file_name)
  224. with open(ouf_dir, 'w') as outfile:
  225. json.dump(all_data, outfile)