server_with_pgfed.py

import copy
import logging
import os
import time

import torch
import torch.distributed as dist
from torchvision import datasets

import model
import utils
from communication import TARGET
from easyfl.datasets.data import CIFAR100
from easyfl.distributed import reduce_models
from easyfl.distributed.distributed import CPU
from easyfl.server import strategies
from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
from easyfl.tracking import metric
from easyfl.protocol import codec
from knn_monitor import knn_monitor
from server import FedSSLServer

logger = logging.getLogger(__name__)


class FedSSLWithPgFedServer(FedSSLServer):
    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        super(FedSSLWithPgFedServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
        self.train_loader = None
        self.test_loader = None

        # PgFed state maintained on the server across rounds.
        self.mu = 0                   # scaling coefficient for the personalized gradient terms
        self.momentum = 0.0           # momentum passed along with the convex-combination gradient
        self.alpha_mat = None         # pairwise weight matrix; row i is sent to client i as a_i
        self.uploaded_grads = {}      # latest gradient uploaded by each selected client, keyed by id
        self.loss_minuses = {}        # client-reported loss_minus values, scaled by mu
        self.mean_grad = None         # (mu / clients_per_round)-weighted sum of uploaded gradients
        self.convex_comb_grad = None  # scratch model reused for per-client gradient combinations
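
    # set_clients: assign sequential integer ids so that each client can index
    # its own row of alpha_mat and the per-client gradient dictionaries.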

    def set_clients(self, clients):
        self._clients = clients
        for i, _ in enumerate(self._clients):
            self._clients[i].id = i
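
    # train: one federated round. Selection, grouping, compression,
    # distribution and aggregation are inherited from the base server; the
    # PgFed-specific steps are send_param before local training and
    # get_mean_grad after aggregation.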

    def train(self):
        """Training process of federated learning."""
        self.print_("--- start training ---")
        print(f"\nJoin clients / total clients: {self.conf.server.clients_per_round} / {len(self._clients)}")
        self.selection(self._clients, self.conf.server.clients_per_round)
        self.grouping_for_distributed()
        self.compression()

        begin_train_time = time.time()

        self.send_param()
        self.distribution_to_train()
        self.aggregation()
        self.get_mean_grad()

        train_time = time.time() - begin_train_time
        self.print_("Server train time: {}".format(train_time))
        self.track(metric.TRAIN_TIME, train_time)
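
    # send_param: broadcast PgFed state to the selected clients. Each client i
    # receives its row a_i of alpha_mat; once gradients have been uploaded, it
    # also receives the previous mean gradient, the combination
    # sum_j(mu * a_{i,j} * grad_j) of last round's gradients, and the
    # mu-scaled loss_minus values.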

    def send_param(self):
        if self.alpha_mat is None:
            self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
        for client in self.grouped_clients:
            client.a_i = self.alpha_mat[client.id]
        if len(self.uploaded_grads) == 0:
            return
        self.convex_comb_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
        for client in self.grouped_clients:
            client.set_prev_mean_grad(self.mean_grad)
            mu_a_i = self.alpha_mat[client.id] * self.mu
            grads, weights = [], []
            for clt_idx, grad in self.uploaded_grads.items():
                weights.append(mu_a_i[clt_idx])
                grads.append(grad)
            self.model_weighted_sum(self.convex_comb_grad, grads, weights)
            client.set_prev_convex_comb_grad(self.convex_comb_grad, momentum=self.momentum)
            client.prev_loss_minuses = copy.deepcopy(self.loss_minuses)

    def distribution_to_train_locally(self):
        """Conduct training sequentially for selected clients in the group."""
        uploaded_models = {}
        uploaded_weights = {}
        uploaded_metrics = []
        for client in self.grouped_clients:
            # Update client config before training
            self.conf.client.task_id = self.conf.task_id
            self.conf.client.round_id = self._current_round

            uploaded_request = client.run_train(self._compressed_model, self.conf.client)
            uploaded_content = uploaded_request.content
            model = self.decompression(codec.unmarshal(uploaded_content.data))
            uploaded_models[client.cid] = model
            uploaded_weights[client.cid] = uploaded_content.data_size
            uploaded_metrics.append(metric.ClientMetric.from_proto(uploaded_content.metric))

        self.receive_param()
        self.set_client_uploads_train(uploaded_models, uploaded_weights, uploaded_metrics)
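
    # receive_param: collect from each selected client its updated row of
    # alpha_mat, its latest local gradient, and its loss_minus (scaled by mu),
    # which feed the personalized combinations of the next round.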

    def receive_param(self):
        self.uploaded_ids = []
        self.uploaded_grads = {}
        self.loss_minuses = {}
        for client in self.selected_clients:
            self.uploaded_ids.append(client.id)
            self.alpha_mat[client.id] = client.a_i
            self.uploaded_grads[client.id] = client.latest_grad
            self.loss_minuses[client.id] = client.loss_minus * self.mu
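
    # get_mean_grad: combine the uploaded gradients with uniform weight
    # mu / clients_per_round, i.e. mean_grad = (mu / K) * sum_j grad_j.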

    def get_mean_grad(self):
        w = self.mu / self.conf.server.clients_per_round
        weights = [w for _ in range(self.conf.server.clients_per_round)]
        self.mean_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
        self.model_weighted_sum(self.mean_grad, list(self.uploaded_grads.values()), weights)
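
    # model_weighted_sum: overwrite `model` in place with
    # sum_i(weights[i] * params(models[i])); the target's parameters are
    # zeroed first.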

    def model_weighted_sum(self, model, models, weights):
        for p_m in model.parameters():
            p_m.data.zero_()
        for w, m_i in zip(weights, models):
            for p_m, p_i in zip(model.parameters(), m_i.parameters()):
                p_m.data += p_i.data.clone() * w
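

# ---------------------------------------------------------------------------
# Minimal sketch (illustration only, not part of the server): exercises
# model_weighted_sum on toy torch.nn.Linear models. The layer sizes and
# weights below are hypothetical and chosen only to show the accumulation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    target = nn.Linear(4, 2)
    sources = [nn.Linear(4, 2), nn.Linear(4, 2)]
    toy_weights = [0.3, 0.7]

    # model_weighted_sum zeroes `target`, then adds toy_weights[i] * params(sources[i]).
    FedSSLWithPgFedServer.model_weighted_sum(None, target, sources, toy_weights)

    expected = 0.3 * sources[0].weight.data + 0.7 * sources[1].weight.data
    assert torch.allclose(target.weight.data, expected)
    print("model_weighted_sum sketch: OK")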