# serverpgfed.py
import copy
import logging
import os

import h5py
import numpy as np
import torch

from flcore.clients.clientpgfed import clientPGFed
from flcore.servers.serverbase import Server
  9. class PGFed(Server):
  10. def __init__(self, args, times):
  11. super().__init__(args, times)
  12. self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}, mu:{args.mu:.5f}, lambda:{args.lambdaa:.5f}"
  13. if self.algorithm == "PGFedMo":
  14. self.momentum = args.beta
  15. self.message_hp += f", beta:{args.beta:.5f}" # momentum
  16. else:
  17. self.momentum = 0.0
  18. clientObj = clientPGFed
  19. self.message_hp_dash = self.message_hp.replace(", ", "-")
  20. self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
  21. self.set_clients(args, clientObj)
  22. self.mu = args.mu
  23. self.alpha_mat = (torch.ones((self.num_clients, self.num_clients)) / self.join_clients).to(self.device)
  24. self.uploaded_grads = {}
  25. self.loss_minuses = {}
  26. self.mean_grad = None
  27. self.convex_comb_grad = None
  28. self.best_global_mean_test_acc = 0.0
  29. self.rs_global_test_acc = []
  30. self.rs_global_test_auc = []
  31. self.rs_global_test_loss = []
  32. self.last_ckpt_fn = os.path.join(self.ckpt_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.pt")
  33. print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
  34. print("Finished creating server and clients.")
  35. def train(self):
  36. early_stop = False
  37. for i in range(self.global_rounds):
  38. self.selected_clients = self.select_clients()
  39. print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
  40. print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
  41. self.send_models()
  42. for client in self.selected_clients:
  43. client.train()
  44. self.receive_models()
  45. self.aggregate_parameters()
  46. if i%self.eval_gap == 0:
  47. print("==> Evaluating personalized models...", flush=True)
  48. self.evaluate()
  49. if i >= 40 and self.check_early_stopping():
  50. early_stop = True
  51. print("==> Performance is too low. Excecuting early stop.")
  52. break
  53. print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
  54. if not early_stop:
  55. self.save_results(fn=self.hist_result_fn)
  56. # message_res = f"\tglobal_test_acc:{self.best_global_mean_test_acc:.6f}\ttest_acc:{self.best_mean_test_acc:.6f}"
  57. message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
  58. logging.info(self.message_hp + message_res)
  59. # state = {
  60. # "model": self.global_model.cpu().state_dict(),
  61. # # "best_global_acc": self.best_global_mean_test_acc,
  62. # "best_personalized_acc": self.best_mean_test_acc,
  63. # "alpha_mat": self.alpha_mat.cpu()
  64. # }
  65. # state.update({f"client{c.id}": c.model.cpu().state_dict() for c in self.clients})
  66. # self.save_global_model(model_path=self.last_ckpt_fn, state=state)
  67. def receive_models(self):
  68. assert (len(self.selected_clients) > 0)
  69. self.uploaded_ids = []
  70. self.uploaded_grads = {}
  71. self.loss_minuses = {}
  72. self.uploaded_models = []
  73. self.uploaded_weights = []
  74. tot_samples = 0
  75. for client in self.selected_clients:
  76. self.uploaded_ids.append(client.id)
  77. self.alpha_mat[client.id] = client.a_i
  78. self.uploaded_grads[client.id] = client.latest_grad
  79. self.loss_minuses[client.id] = client.loss_minus * self.mu
  80. self.uploaded_weights.append(client.train_samples)
  81. tot_samples += client.train_samples
  82. self.uploaded_models.append(client.model)
  83. for i, w in enumerate(self.uploaded_weights):
  84. self.uploaded_weights[i] = w / tot_samples
  85. def aggregate_parameters(self):
  86. assert (len(self.uploaded_grads) > 0)
  87. self.model_weighted_sum(self.global_model, self.uploaded_models, self.uploaded_weights)
  88. w = self.mu/self.join_clients
  89. weights = [w for _ in range(self.join_clients)]
  90. self.mean_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
  91. self.model_weighted_sum(self.mean_grad, list(self.uploaded_grads.values()), weights)
  92. def model_weighted_sum(self, model, models, weights):
  93. for p_m in model.parameters():
  94. p_m.data.zero_()
  95. for w, m_i in zip(weights, models):
  96. for p_m, p_i in zip(model.parameters(), m_i.parameters()):
  97. p_m.data += p_i.data.clone() * w
  98. def send_models(self, mode="selected"):
  99. assert (len(self.selected_clients) > 0)
  100. for client in self.selected_clients:
  101. client.a_i = self.alpha_mat[client.id]
  102. client.set_parameters(self.global_model)
  103. if len(self.uploaded_grads) == 0:
  104. return
  105. self.convex_comb_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
  106. for client in self.selected_clients:
  107. client.set_prev_mean_grad(self.mean_grad)
  108. mu_a_i = self.alpha_mat[client.id] * self.mu
  109. grads, weights = [], []
  110. for clt_idx, grad in self.uploaded_grads.items():
  111. weights.append(mu_a_i[clt_idx])
  112. grads.append(grad)
  113. self.model_weighted_sum(self.convex_comb_grad, grads, weights)
  114. client.set_prev_convex_comb_grad(self.convex_comb_grad, momentum=self.momentum)
  115. client.prev_loss_minuses = copy.deepcopy(self.loss_minuses)