server.py

import copy
import logging
import os
import shutil
import time
from collections import defaultdict

import torch

from dataset import DataPrefetcher
from losses import get_losses
from utils import AverageMeter

from easyfl.distributed.distributed import CPU
from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
from easyfl.tracking import metric

logger = logging.getLogger(__name__)


class MASServer(BaseServer):
    """Federated server for the MAS application: aggregates client models,
    tracks validation loss across rounds, and checkpoints the best model."""

    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        super(MASServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
        self.train_loader = None
        self.test_loader = None

        self._progress_table = []
        self._stats = []
        self._loss_history = []
        self._current_loss = 9e9  # sentinel "infinity" before the first validation
        self._best_loss = 9e9
        self._best_model = None
        self._client_models = []

    def aggregation(self):
        """Aggregate uploaded client models, weighted by client data size."""
        uploaded_content = self.get_client_uploads()
        models = list(uploaded_content[MODEL].values())
        weights = list(uploaded_content[DATA_SIZE].values())

        # Cache client models for saving in checkpoints.
        self._client_models = [copy.deepcopy(m).cpu() for m in models]

        model = self.aggregate(models, weights)
        self.set_model(model, load_dict=True)

    def test_in_server(self, device=CPU):
        # Validate the aggregated model every round.
        val_loader = self.val_data.loader(
            batch_size=max(self.conf.server.batch_size // 2, 1),
            shuffle=False,
            seed=self.conf.seed)
        test_results, stats, progress = self.test_fn(val_loader, self._model, device)
        self._current_loss = float(stats['Loss'])
        self._stats.append(stats)
        self._loss_history.append(self._current_loss)
        self._progress_table.append(progress)
        logger.info(f"Validation statistics: {stats}")

        # Run the test set once, at the final round.
        if self._current_round == self.conf.server.rounds - 1:
            test_loader = self.test_data.loader(
                batch_size=max(self.conf.server.batch_size // 2, 1),
                shuffle=False,
                seed=self.conf.seed)
            _, stats, progress_table = self.test_fn(test_loader, self._model, device)
            logger.info(f"Testing statistics of last round: {stats}")

            if self._current_loss <= self._best_loss:
                logger.info(f"Last round {self._current_round} is the best round")
            else:
                # The final model is not the best one; also test the best model.
                _, stats, progress_table = self.test_fn(test_loader, self._best_model, device)
                logger.info(f"Testing statistics of best model: {stats}")

        return test_results

    def test_fn(self, loader, model, device=CPU):
        """Evaluate `model` on `loader`; return (test_results, stats, progress_table)."""
        model.eval()
        model.to(device)

        criteria = get_losses(self.conf.client.task_str, self.conf.client.rotate_loss, self.conf.client.task_weights)

        average_meters = defaultdict(AverageMeter)
        epoch_start_time = time.time()
        batch_num = 0
        num_data_points = len(loader)

        prefetcher = DataPrefetcher(loader, device)
        with torch.no_grad():
            for i in range(len(loader)):
                input, target = prefetcher.next()

                if batch_num == 0:
                    # Start the ETA clock after the first batch to exclude warm-up.
                    epoch_start_time2 = time.time()

                output = model(input)

                loss_dict = {}
                for c_name, criterion_fn in criteria.items():
                    loss_dict[c_name] = criterion_fn(output, target)

                batch_num = i + 1

                for name, value in loss_dict.items():
                    try:
                        average_meters[name].update(value.data)
                    except AttributeError:
                        # Plain numbers have no .data attribute.
                        average_meters[name].update(value)

                eta = ((time.time() - epoch_start_time2) / (batch_num + .2)) * (len(loader) - batch_num)
                to_print = {
                    f'#/{num_data_points}': f'{batch_num}',
                    'eta': time.strftime("%H:%M:%S", time.gmtime(int(eta))),
                }
                for name in criteria.keys():
                    meter = average_meters[name]
                    to_print[name] = f'{meter.avg:.4f}'

        epoch_time = time.time() - epoch_start_time

        stats = {'batches': len(loader), 'epoch_time': epoch_time}
        for name in criteria.keys():
            meter = average_meters[name]
            stats[name] = meter.avg

        to_print['eta'] = time.strftime("%H:%M:%S", time.gmtime(int(epoch_time)))
        torch.cuda.empty_cache()

        test_results = {
            metric.TEST_ACCURACY: 0,
            metric.TEST_LOSS: float(stats['Loss']),
        }

        return test_results, stats, [to_print]

    def save_model(self):
        """Checkpoint the aggregated model (and cached client models) on the primary server."""
        if self._do_every(self.conf.server.save_model_every, self._current_round, self.conf.server.rounds) and \
                self.is_primary_server():
            save_path = self.conf.server.save_model_path
            if save_path == "":
                save_path = os.path.join(os.getcwd(), "saved_models", "mas", self.conf.task_id)
            os.makedirs(save_path, exist_ok=True)
            if self.conf.server.save_model_every == 1:
                # Saving every round: keep a single rolling checkpoint.
                save_filename = f"{self.conf.task_id}_checkpoint.pth.tar"
            else:
                save_filename = f"{self.conf.task_id}_r_{self._current_round}_checkpoint.pth.tar"

            is_best = self._current_loss < self._best_loss
            self._best_loss = min(self._current_loss, self._best_loss)
            try:
                checkpoint = {
                    'round': self._current_round,
                    'info': {'machine': self.conf.distributed.init_method, 'GPUS': self.conf.gpu},
                    'args': self.conf,
                    'arch': self.conf.arch,
                    'state_dict': self._model.cpu().state_dict(),  # note: moves the model to CPU
                    'best_loss': self._best_loss,
                    'progress_table': self._progress_table,
                    'stats': self._stats,
                    'loss_history': self._loss_history,
                    'code_archive': self.get_code_archive(),
                    'client_models': [m.cpu().state_dict() for m in self._client_models],
                }
                self.save_checkpoint(checkpoint, False, save_path, save_filename)
                if is_best:
                    logger.info(f"Best validation loss at round {self._current_round}: {self._best_loss}")
                    self._best_model = copy.deepcopy(self._model)
                    self.save_checkpoint(None, True, save_path, save_filename)
                self.print_(f"Checkpoint saved at {save_path}")
            except Exception as e:
                self.print_(f"Save checkpoint failed: {e}")

    def save_checkpoint(self, state, is_best, directory='', filename='checkpoint.pth.tar'):
        path = os.path.join(directory, filename)
        if is_best:
            # The regular checkpoint was just written; copy it as the best one.
            best_path = os.path.join(directory, f"best_{self.conf.task_id}_checkpoint.pth.tar")
            shutil.copyfile(path, best_path)
        else:
            torch.save(state, path)

    def get_code_archive(self):
        """Snapshot the source of every .py file in the working directory."""
        file_contents = {}
        for name in os.listdir('.'):
            if name.endswith('.py'):
                with open(name, 'r') as file:
                    file_contents[name] = file.read()
        return file_contents
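

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): one plausible
# way to wire MASServer into an EasyFL run via the library's register_* API.
# `MASClient` and the config dict's keys are assumptions standing in for this
# application's client class and configuration; check the application's entry
# script for the actual setup.
# ---------------------------------------------------------------------------
# import easyfl
# from client import MASClient  # assumed client counterpart in this application
#
# easyfl.register_client(MASClient)   # custom client-side training logic
# easyfl.register_server(MASServer)   # the server defined above
# easyfl.init({"server": {"rounds": 50}})  # config keys illustrative
# easyfl.run()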