remove_users.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. """
  2. Removes users with fewer than the given number of samples.
  3. This code is adapted from LEAF with some modifications.
  4. """
  5. import json
  6. import logging
  7. import os
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  9. def remove(setting_folder, dataset, min_samples):
  10. parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
  11. dir = os.path.join(parent_path, dataset, "data")
  12. subdir = os.path.join(dir, setting_folder, "sampled_data")
  13. files = []
  14. if os.path.exists(subdir):
  15. files = os.listdir(subdir)
  16. if len(files) == 0:
  17. subdir = os.path.join(dir, "all_data")
  18. files = os.listdir(subdir)
  19. files = [f for f in files if f.endswith(".json")]
  20. for f in files:
  21. users = []
  22. hierarchies = []
  23. num_samples = []
  24. user_data = {}
  25. file_dir = os.path.join(subdir, f)
  26. with open(file_dir, "r") as inf:
  27. data = json.load(inf)
  28. num_users = len(data["users"])
  29. for i in range(num_users):
  30. curr_user = data["users"][i]
  31. curr_hierarchy = None
  32. if "hierarchies" in data:
  33. curr_hierarchy = data["hierarchies"][i]
  34. curr_num_samples = data["num_samples"][i]
  35. if (curr_num_samples >= min_samples):
  36. user_data[curr_user] = data["user_data"][curr_user]
  37. users.append(curr_user)
  38. if curr_hierarchy is not None:
  39. hierarchies.append(curr_hierarchy)
  40. num_samples.append(data["num_samples"][i])
  41. all_data = {}
  42. all_data["users"] = users
  43. if len(hierarchies) == len(users):
  44. all_data["hierarchies"] = hierarchies
  45. all_data["num_samples"] = num_samples
  46. all_data["user_data"] = user_data
  47. file_name = "{}_keep_{}.json".format((f[:-5]), min_samples)
  48. ouf_dir = os.path.join(dir, setting_folder, "rem_user_data", file_name)
  49. logger.info("writing {}".format(file_name))
  50. with open(ouf_dir, "w") as outfile:
  51. json.dump(all_data, outfile)