serverlgfedavg.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from flcore.clients.clientlgfedavg import clientLGFedAvg
  2. from flcore.servers.serverbase import Server
  3. import copy
  4. import os
  5. import logging
  6. class LGFedAvg(Server):
  7. def __init__(self, args, times):
  8. super().__init__(args, times)
  9. self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
  10. clientObj = clientLGFedAvg
  11. self.message_hp_dash = self.message_hp.replace(", ", "-")
  12. self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
  13. self.set_clients(args, clientObj)
  14. print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
  15. print("Finished creating server and clients.")
  16. # self.load_model()
  17. self.global_model = self.global_model.predictor
  18. def train(self):
  19. for i in range(self.global_rounds):
  20. self.selected_clients = self.select_clients()
  21. self.send_models()
  22. print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
  23. print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
  24. for client in self.selected_clients:
  25. client.train()
  26. self.receive_models()
  27. self.aggregate_parameters()
  28. if i%self.eval_gap == 0:
  29. print("==> Evaluating personalized models...", flush=True)
  30. self.send_models(mode="all")
  31. self.evaluate(self.global_model)
  32. print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
  33. self.save_results(fn=self.hist_result_fn)
  34. message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
  35. logging.info(self.message_hp + message_res)
  36. def receive_models(self):
  37. assert (len(self.selected_clients) > 0)
  38. self.uploaded_weights = []
  39. tot_samples = 0
  40. self.uploaded_ids = []
  41. self.uploaded_models = []
  42. for client in self.selected_clients:
  43. self.uploaded_weights.append(client.train_samples)
  44. tot_samples += client.train_samples
  45. self.uploaded_ids.append(client.id)
  46. self.uploaded_models.append(client.model.predictor)
  47. for i, w in enumerate(self.uploaded_weights):
  48. self.uploaded_weights[i] = w / tot_samples
  49. def prepare_global_model(self):
  50. temp_model = copy.deepcopy(self.global_model) # predictor
  51. self.global_model = copy.deepcopy(self.clients[0].model)
  52. for p_t, p_g in zip(temp_model.parameters(), self.global_model.predictor.parameters()):
  53. p_g.data = p_t.data.clone()
  54. for p in self.global_model.base.parameters():
  55. p.data.zero_()
  56. for c in self.clients:
  57. for p_g, p_c in zip(self.global_model.base.parameters(), c.model.base.parameters()):
  58. p_g.data += p_c.data * c.train_samples
  59. return