# serverbabu.py — FedBABU server-side implementation.
  1. from flcore.clients.clientbabu import clientBABU
  2. from flcore.servers.serverbase import Server
  3. import torch
  4. import os
  5. import sys
  6. import logging
  7. class FedBABU(Server):
  8. def __init__(self, args, times):
  9. super().__init__(args, times)
  10. self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}, alpha:{args.alpha:.5f}"
  11. self.message_hp_dash = self.message_hp.replace(", ", "-")
  12. self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
  13. self.set_clients(args, clientBABU)
  14. print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
  15. print("Finished creating server and clients.")
  16. # self.load_model()
  17. def train(self):
  18. for i in range(self.global_rounds):
  19. self.selected_clients = self.select_clients()
  20. self.send_models()
  21. print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
  22. print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
  23. for client in self.selected_clients:
  24. client.train()
  25. self.receive_models()
  26. self.aggregate_parameters()
  27. if i%self.eval_gap == 0:
  28. print("==> Evaluating global models...", flush=True)
  29. self.send_models(mode="all")
  30. self.evaluate(mode="global")
  31. if i > 40:
  32. self.check_early_stopping()
  33. print("\n--------------------- Fine-tuning ----------------------")
  34. self.send_fine_tune_models(mode="all")
  35. for client in self.clients:
  36. client.fine_tune()
  37. print("------------- Evaluating fine-tuned models -------------")
  38. self.evaluate(mode="personalized")
  39. print(f"==> Mean personalized accuracy: {self.rs_test_acc[-1]*100:.2f}", flush=True)
  40. message_res = f"\ttest_acc:{self.rs_test_acc[-1]:.6f}"
  41. logging.info(self.message_hp + message_res)
  42. def aggregate_parameters(self):
  43. assert (len(self.uploaded_models) > 0)
  44. for param in self.global_model.parameters():
  45. param.data.zero_()
  46. for w, client_model in zip(self.uploaded_weights, self.uploaded_models):
  47. # self.uploaded_models are a list of client.model.base's
  48. self.add_parameters(w, client_model)
  49. # after self.aggregate_parameters(), the self.global_model are still a model with base and predictor
  50. def send_fine_tune_models(self, mode="selected"):
  51. if mode == "selected":
  52. assert (len(self.selected_clients) > 0)
  53. for client in self.selected_clients:
  54. client.set_fine_tune_parameters(self.global_model)
  55. elif mode == "all":
  56. for client in self.clients:
  57. client.set_fine_tune_parameters(self.global_model)
  58. else:
  59. raise NotImplementedError
  60. def receive_models(self):
  61. assert (len(self.selected_clients) > 0)
  62. self.uploaded_weights = []
  63. tot_samples = 0
  64. self.uploaded_ids = []
  65. self.uploaded_models = []
  66. for client in self.selected_clients:
  67. self.uploaded_weights.append(client.train_samples)
  68. tot_samples += client.train_samples
  69. self.uploaded_ids.append(client.id)
  70. self.uploaded_models.append(client.model.base)
  71. for i, w in enumerate(self.uploaded_weights):
  72. self.uploaded_weights[i] = w / tot_samples
  73. def load_model(self, model_path=None):
  74. if model_path is None:
  75. model_path = os.path.join("models", self.dataset)
  76. model_path = os.path.join(model_path, self.algorithm + "_server" + ".pt")
  77. assert (os.path.exists(model_path))
  78. return torch.load(model_path)