# serverdyn.py
import copy
import logging
import os

import torch

from flcore.clients.clientdyn import clientDyn
from flcore.servers.serverbase import Server
  7. class FedDyn(Server):
  8. def __init__(self, args, times):
  9. super().__init__(args, times)
  10. self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}, alpha:{args.alpha:.5f}"
  11. clientObj = clientDyn
  12. self.message_hp_dash = self.message_hp.replace(", ", "-")
  13. self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
  14. self.set_clients(args, clientObj)
  15. print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
  16. print("Finished creating server and clients.")
  17. # self.load_model()
  18. self.Budget = []
  19. self.alpha = args.alpha
  20. self.server_state = copy.deepcopy(args.model)
  21. for param in self.server_state.parameters():
  22. param.data.zero_()
  23. def train(self):
  24. for i in range(self.global_rounds):
  25. self.selected_clients = self.select_clients()
  26. self.send_models()
  27. print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
  28. print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
  29. for client in self.selected_clients:
  30. client.train()
  31. self.receive_models()
  32. self.update_server_state()
  33. self.aggregate_parameters()
  34. if i%self.eval_gap == 0:
  35. print("==> Evaluating global models...", flush=True)
  36. self.send_models(mode="all")
  37. # self.evaluate(mode="global")
  38. self.evaluate()
  39. if i == 80:
  40. self.check_early_stopping()
  41. print(f"==> Best mean global accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
  42. self.save_results(fn=self.hist_result_fn)
  43. message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
  44. logging.info(self.message_hp + message_res)
  45. # self.save_global_model()
  46. def add_parameters(self, client_model):
  47. for server_param, client_param in zip(self.global_model.parameters(), client_model.parameters()):
  48. server_param.data += client_param.data.clone() / self.join_clients
  49. def aggregate_parameters(self):
  50. assert (len(self.uploaded_models) > 0)
  51. self.global_model = copy.deepcopy(self.uploaded_models[0])
  52. for param in self.global_model.parameters():
  53. param.data.zero_()
  54. for client_model in self.uploaded_models:
  55. self.add_parameters(client_model)
  56. for server_param, state_param in zip(self.global_model.parameters(), self.server_state.parameters()):
  57. server_param.data -= (1/self.alpha) * state_param.data
  58. def update_server_state(self):
  59. assert (len(self.uploaded_models) > 0)
  60. model_delta = copy.deepcopy(self.uploaded_models[0])
  61. for param in model_delta.parameters():
  62. param.data.zero_()
  63. for client_model in self.uploaded_models:
  64. for server_param, client_param, delta_param in zip(self.global_model.parameters(), client_model.parameters(), model_delta.parameters()):
  65. delta_param.data += (client_param - server_param) / self.num_clients
  66. for state_param, delta_param in zip(self.server_state.parameters(), model_delta.parameters()):
  67. state_param.data -= self.alpha * delta_param