serverpfedme.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import os
  2. import logging
  3. import copy
  4. import h5py
  5. from flcore.clients.clientpfedme import clientpFedMe
  6. from flcore.servers.serverbase import Server
  7. class pFedMe(Server):
  8. def __init__(self, args, times):
  9. super().__init__(args, times)
  10. self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
  11. clientObj = clientpFedMe
  12. self.message_hp_dash = self.message_hp.replace(", ", "-")
  13. self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
  14. self.set_clients(args, clientObj)
  15. self.beta = args.beta
  16. self.rs_train_acc_per = []
  17. self.rs_train_loss_per = []
  18. self.rs_test_acc_per = []
  19. print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
  20. print("Finished creating server and clients.")
  21. def train(self):
  22. for i in range(self.global_rounds):
  23. self.selected_clients = self.select_clients()
  24. self.send_models()
  25. print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
  26. print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
  27. for client in self.selected_clients:
  28. client.train()
  29. if i%self.eval_gap == 0:
  30. print("==> Evaluating personalized models...")
  31. self.evaluate()
  32. self.previous_global_model = copy.deepcopy(list(self.global_model.parameters()))
  33. self.receive_models()
  34. self.aggregate_parameters()
  35. self.beta_aggregate_parameters()
  36. self.save_results(fn=self.hist_result_fn)
  37. message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
  38. logging.info(self.message_hp + message_res)
  39. # self.save_global_model()
  40. def beta_aggregate_parameters(self):
  41. # aggregate avergage model with previous model using parameter beta
  42. for pre_param, param in zip(self.previous_global_model, self.global_model.parameters()):
  43. param.data = (1 - self.beta)*pre_param.data + self.beta*param.data