server.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import copy
  2. import logging
  3. import os
  4. import torch
  5. import torch.distributed as dist
  6. from torchvision import datasets
  7. import model
  8. import utils
  9. from communication import TARGET
  10. from easyfl.datasets.data import CIFAR100
  11. from easyfl.distributed import reduce_models
  12. from easyfl.distributed.distributed import CPU
  13. from easyfl.server import strategies
  14. from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
  15. from easyfl.tracking import metric
  16. from knn_monitor import knn_monitor
  # Module-level logger, named after this module.
  logger = logging.getLogger(name=__name__)
  class FedSSLServer(BaseServer):
      """Server for federated self-supervised learning.

      Averages only the sub-networks that exist for the configured SSL method
      (online/target encoders, online predictor, MoCo query/key encoders), evaluates
      the aggregated encoder with ``knn_monitor`` on CIFAR-10/100, and periodically
      saves the encoder (and optionally the predictor) to disk.
      """

      def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
          super(FedSSLServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
          # Loaders for knn_monitor evaluation; built lazily by _get_test_data.
          self.train_loader = None
          self.test_loader = None

      def aggregation(self):
          """Average the uploaded client models component by component into ``self._model``.

          The components averaged depend on ``self.conf.model``:
          the online encoder for Symmetric/SimSiam/BYOL/SimCLR variants, the online
          predictor for SimSiam/BYOL variants, the target encoder for BYOL variants,
          and both query and key encoders for MoCo variants. Client contributions are
          weighted by the uploaded ``DATA_SIZE`` values.
          """
          if self.conf.client.auto_scaler == 'y' and self.conf.server.random_selection:
              self._retain_weight_scaler()

          uploaded_content = self.get_client_uploads()
          models = list(uploaded_content[MODEL].values())
          weights = list(uploaded_content[DATA_SIZE].values())

          # Aggregate networks gradually with different components.
          if self.conf.model in [model.Symmetric, model.SymmetricNoSG, model.SimSiam, model.SimSiamNoSG, model.BYOL,
                                 model.BYOLNoSG, model.BYOLNoPredictor, model.SimCLR]:
              self._aggregate_component(models, weights, "online_encoder")

          if self.conf.model in [model.SimSiam, model.SimSiamNoSG, model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
              self._aggregate_component(models, weights, "online_predictor")

          if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
              target_encoder = self._federated_averaging([m.target_encoder for m in models], weights)
              # Replaced wholesale instead of load_state_dict; presumably because the global
              # target encoder may not exist yet on the server model - TODO confirm.
              self._model.target_encoder = copy.deepcopy(target_encoder)

          if self.conf.model in [model.MoCo, model.MoCoV2]:
              self._aggregate_component(models, weights, "encoder_q")
              self._aggregate_component(models, weights, "encoder_k")

      def _aggregate_component(self, models, weights, name):
          """Average sub-module ``name`` over ``models`` and load it into the global model in place."""
          averaged = self._federated_averaging([getattr(m, name) for m in models], weights)
          getattr(self._model, name).load_state_dict(averaged.state_dict())

      def _retain_weight_scaler(self):
          """Share the first grouped client's weight scaler of every rank with all ranks.

          Each rank contributes a ``(client_index, weight_scaler)`` pair through
          ``dist.all_gather``. Afterwards, every local client whose ``weight_scaler`` is
          still falsy adopts the gathered value for its index. Requires an initialised
          default process group.
          """
          self.client_id_to_index = {c.cid: i for i, c in enumerate(self._clients)}

          first_client = self.grouped_clients[0]
          client_index = self.client_id_to_index[first_client.cid]
          weight_scaler = first_client.weight_scaler if first_client.weight_scaler else 0
          # Fix the dtype explicitly: letting torch infer it yields int64 on a rank whose
          # scaler is unset (0) but float32 elsewhere, and all_gather needs matching dtypes.
          scaler = torch.tensor((client_index, weight_scaler), dtype=torch.float64).to(self.conf.device)
          # all_gather requires exactly one output tensor per rank in the group.
          scalers = [torch.zeros_like(scaler) for _ in range(dist.get_world_size())]
          dist.barrier()
          dist.all_gather(scalers, scaler)
          logger.info(f"Synced scaler {scalers}")

          # Move to host once rather than once per (client, scaler) pair.
          synced = [s.cpu().numpy() for s in scalers]
          for i, client in enumerate(self._clients):
              for index, value in synced:
                  if self.client_id_to_index[client.cid] == int(index) and not client.weight_scaler:
                      self._clients[i].weight_scaler = value

      def _federated_averaging(self, models, weights):
          """Return the ``weights``-weighted average of ``models``.

          In distributed mode each rank computes a local weighted sum with
          ``strategies.weighted_sum`` and passes it, together with its local weight
          total, to ``reduce_models`` (expected to combine the ranks in place - TODO
          confirm). Otherwise ``strategies.federated_averaging`` runs locally.
          """
          if not self.conf.is_distributed:
              return strategies.federated_averaging(models, weights)
          dist.barrier()
          model_, sample_sum = strategies.weighted_sum(models, weights)
          reduce_models(model_, torch.tensor(sample_sum).to(self.conf.device))
          return model_

      def test_in_server(self, device=CPU):
          """Evaluate the aggregated encoder with ``knn_monitor``.

          Args:
              device: device the testing model is moved to before evaluation.

          Returns:
              dict with ``metric.TEST_ACCURACY`` set to the kNN accuracy and
              ``metric.TEST_LOSS`` set to 0 (no loss is computed here).
          """
          testing_model = self._get_testing_model()
          # NOTE: eval()/to() act on the global model's own sub-module, not on a copy.
          testing_model.eval()
          testing_model.to(device)

          self._get_test_data()

          with torch.no_grad():
              accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader)

          return {
              metric.TEST_ACCURACY: float(accuracy),
              metric.TEST_LOSS: 0,
          }

      def _get_test_data(self):
          """Lazily build the CIFAR train/test loaders used by ``test_in_server``.

          Uses CIFAR-100 when ``conf.data.dataset`` equals ``CIFAR100`` and CIFAR-10
          otherwise, downloading under ./data if needed. Loaders that already exist
          are left untouched.
          """
          if self.train_loader is not None and self.test_loader is not None:
              return

          # Only build transforms and datasets when at least one loader is missing.
          transformation = self._load_transform()
          if self.conf.data.dataset == CIFAR100:
              data_path = "./data/cifar100"
              dataset_cls = datasets.CIFAR100
          else:
              data_path = "./data/cifar10"
              dataset_cls = datasets.CIFAR10
          train_dataset = dataset_cls(data_path, download=True, transform=transformation)
          test_dataset = dataset_cls(data_path, train=False, download=True, transform=transformation)

          if self.train_loader is None:
              self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, num_workers=8)
          if self.test_loader is None:
              self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512, num_workers=8)

      def _load_transform(self):
          """Return the test-time transform of the transformation class for ``conf.model``."""
          transformation = utils.get_transformation(self.conf.model)
          return transformation().test_transform

      def _get_testing_model(self, net=False):
          """Return the sub-module of the global model used for evaluation and saving.

          MoCo variants use ``encoder_q``; SimSiam/Symmetric/SimCLR variants use
          ``online_encoder``. All other models (the BYOL family) use ``target_encoder``
          when ``conf.client.aggregate_encoder`` is ``TARGET`` and ``online_encoder``
          otherwise. ``net`` is accepted for compatibility but unused.
          """
          if self.conf.model in [model.MoCo, model.MoCoV2]:
              return self._model.encoder_q
          if self.conf.model in [model.SimSiam, model.SimSiamNoSG, model.Symmetric, model.SymmetricNoSG, model.SimCLR]:
              return self._model.online_encoder
          # BYOL
          if self.conf.client.aggregate_encoder == TARGET:
              self.print_("Use aggregated target encoder for testing")
              return self._model.target_encoder
          self.print_("Use aggregated online encoder for testing")
          return self._model.online_encoder

      @staticmethod
      def _cpu_state_dict(module):
          """Return a copy of ``module.state_dict()`` with every tensor on the CPU.

          Unlike ``module.cpu().state_dict()``, this does not move ``module`` itself.
          Non-tensor entries are kept as they are.
          """
          state = module.state_dict()
          # Shallow copy keeps the mapping type and any attributes attached to it.
          cpu_state = copy.copy(state)
          cpu_state.update(
              (key, value.cpu() if torch.is_tensor(value) else value) for key, value in state.items()
          )
          return cpu_state

      def save_model(self):
          """Save the testing encoder, and optionally the predictor, on scheduled rounds.

          Runs only when ``_do_every`` selects the current round and this is the
          primary server. Files go to ``conf.server.save_model_path`` or, when that is
          empty, ``<cwd>/saved_models/<task_id>``. They are named
          ``<task_id>_global_model_r_<round>.pth`` for the encoder and
          ``<task_id>_predictor_r_<round>.pth`` for the predictor (SimSiam/BYOL only,
          and only if ``conf.server.save_predictor``).
          """
          if not (self._do_every(self.conf.server.save_model_every, self._current_round, self.conf.server.rounds)
                  and self.is_primary_server()):
              return

          save_dir = self.conf.server.save_model_path
          if not save_dir:
              save_dir = os.path.join(os.getcwd(), "saved_models", self.conf.task_id)
          os.makedirs(save_dir, exist_ok=True)
          save_path = os.path.join(save_dir,
                                   "{}_global_model_r_{}.pth".format(self.conf.task_id, self._current_round))

          # Save a CPU copy so the live global model stays on its current device.
          torch.save(self._cpu_state_dict(self._get_testing_model()), save_path)
          self.print_("Encoder model saved at {}".format(save_path))

          if self.conf.server.save_predictor and self.conf.model in [model.SimSiam, model.BYOL]:
              # Rename only the file, not the whole path: a directory containing
              # "global_model" would otherwise be rewritten as well.
              predictor_name = os.path.basename(save_path).replace("global_model", "predictor")
              save_path = os.path.join(save_dir, predictor_name)
              torch.save(self._cpu_state_dict(self._model.online_predictor), save_path)
              self.print_("Predictor model saved at {}".format(save_path))