# split.py
# Parses per-client task-affinity ("transference") entries from a training log
# and runs network selection (task grouping) on them.
  1. import argparse
  2. import collections
  3. import copy
  4. import functools
  5. import json
  6. import operator
  7. import re
  8. from pprint import pprint
  9. import matplotlib
  10. import matplotlib.pyplot as plt
  11. matplotlib.rcParams['text.usetex'] = True
  12. plt.rcParams.update({'font.size': 30})
  13. import numpy as np
  14. import network_selection
  15. MAPPING = {
  16. 'ss_l': 's',
  17. 'depth_l': 'd',
  18. 'norm_l': 'n',
  19. 'key_l': 'k',
  20. 'edge2d_l': 't',
  21. 'edge_l': 'e',
  22. 'shade_l': 'r',
  23. 'rgb_l': 'a',
  24. 'pc_l': 'c',
  25. }
  26. COLOR_MAP = {
  27. 'ss_l': 'tab:blue',
  28. 'depth_l': 'tab:orange',
  29. 'norm_l': 'tab:green',
  30. 'key_l': 'tab:red',
  31. 'edge2d_l': 'tab:purple',
  32. 'edge_l': 'tab:brown',
  33. 'shade_l': 'tab:pink',
  34. 'rgb_l': 'tab:gray',
  35. 'pc_l': 'tab:olive',
  36. }
  37. class Affinity:
  38. def __init__(self, args=None):
  39. self.affinities = {}
  40. self.args = args
  41. self.task_overlap = args.task_overlap
  42. self.split = args.split
  43. def add(self, round_id, client_id, affinity):
  44. if self.args.preprocess:
  45. affinity = self.preprocess_affinity(affinity)
  46. for scores in affinity.values():
  47. if isinstance(scores, list) and scores[0]['ss_l'] == 0.0:
  48. return
  49. else:
  50. break
  51. if round_id not in self.affinities:
  52. self.affinities[round_id] = {client_id: affinity}
  53. else:
  54. self.affinities[round_id][client_id] = affinity
  55. def get_round_affinities(self, round_id):
  56. return list(self.affinities[round_id].values())
  57. def average_affinities(self, affinities):
  58. result = copy.deepcopy(affinities[0])
  59. for task, affinity in result.items():
  60. for target_task, score in affinity.items():
  61. total = score
  62. for a in affinities[1:]:
  63. total += a[task][target_task]
  64. result[task][target_task] = total / len(affinities)
  65. return result
  66. def average_affinity_of_clients(self, max_round=100):
  67. affinities = {}
  68. for round_id, affinity in self.affinities.items():
  69. if round_id >= max_round:
  70. continue
  71. result = self.average_affinities(list(affinity.values()))
  72. affinities[round_id] = result
  73. return affinities
  74. def average_affinity_of_rounds(self, max_round=100):
  75. affinities = self.average_affinity_of_clients(max_round)
  76. return self.average_affinities(list(affinities.values()))
  77. def preprocess_affinity(self, affinity):
  78. for task, scores in affinity.items():
  79. result = dict(functools.reduce(operator.add, map(collections.Counter, scores)))
  80. affinity[task] = result
  81. return affinity
  82. def network_selection(self, rounds, specific_round=False):
  83. results = {}
  84. # Network selection of specific round
  85. if specific_round:
  86. for round_id in rounds:
  87. round_affinities = self.get_round_affinities(round_id)
  88. # Network selection of average
  89. averaged_affinity = self.average_affinities(round_affinities)
  90. result = network_selection.task_grouping(averaged_affinity, task_overlap=self.task_overlap,
  91. split=self.split)
  92. results[round_id] = {"average": result}
  93. # pprint(averaged_affinity)
  94. if not self.args.average_only:
  95. for client, a in self.affinities[round_id].items():
  96. result = network_selection.task_grouping(a, task_overlap=self.task_overlap, split=self.split)
  97. results[round_id][client] = result
  98. # Average task affinity of all rounds
  99. for round_id in rounds:
  100. affinities = self.average_affinity_of_rounds(round_id)
  101. results[f"average_{round_id}"] = network_selection.task_grouping(affinities, task_overlap=self.task_overlap,
  102. split=self.split)
  103. # Convert string formats from loss to single letter
  104. return results
  105. def extract_task_affinity(line):
  106. r = re.search(r'[\d\w\-\[\]\:\ ,]* Round (\d+) - Client (\w+) transference: (\{[\{\}\[\]\'\-\_\d\w\: .,]*\}\n)',
  107. line)
  108. if not r:
  109. return
  110. return r.groups()
  111. def run(args):
  112. A = Affinity(args)
  113. with open(args.filename, 'r') as f:
  114. for line in f:
  115. data = extract_task_affinity(line)
  116. if not data:
  117. continue
  118. round_id, client_id, affinity = data
  119. round_id = int(round_id)
  120. affinity = affinity.replace("'", "\"")
  121. affinity = json.loads(affinity)
  122. A.add(round_id, client_id, affinity)
  123. else:
  124. results = A.network_selection(args.rounds)
  125. results_str = json.dumps(results)
  126. for loss_name, char in MAPPING.items():
  127. results_str = results_str.replace(loss_name, char)
  128. results = json.loads(results_str)
  129. pprint(results)
  130. def construct_analyze_parser(parser):
  131. parser.add_argument('-f', '--filename', type=str, metavar='PATH', default="./train.log")
  132. parser.add_argument('-s', '--split', type=int, default=3)
  133. parser.add_argument('-o', '--task_overlap', action='store_true')
  134. parser.add_argument('-p', '--preprocess', action='store_true')
  135. parser.add_argument('-a', '--average_only', action='store_true')
  136. parser.add_argument('-r', '--rounds', nargs="*", default=[10], type=int)
  137. return parser
  138. if __name__ == '__main__':
  139. parser = argparse.ArgumentParser(description='Split')
  140. parser = construct_analyze_parser(parser)
  141. args = parser.parse_args()
  142. print("args:", args)
  143. run(args)