# serverlocal.py
import logging
import os

from flcore.clients.clientavg import clientAVG
from flcore.servers.serverbase import Server


  5. class Local(Server):
  6. def __init__(self, args, times):
  7. super().__init__(args, times)
  8. self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
  9. clientObj = clientAVG
  10. self.message_hp_dash = self.message_hp.replace(", ", "-")
  11. self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
  12. self.set_clients(args, clientObj)
  13. print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
  14. print("Finished creating server and clients.")
  15. # self.load_model()
  16. def train(self):
  17. for i in range(self.global_rounds):
  18. self.selected_clients = self.select_clients()
  19. print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
  20. print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
  21. for client in self.selected_clients:
  22. client.train()
  23. if i%self.eval_gap == 0:
  24. print(f"\n-------------Round number: {i}-------------")
  25. print("\nEvaluate local model")
  26. self.evaluate()
  27. print(f"==> Best accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
  28. self.save_results(fn=self.hist_result_fn)
  29. message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
  30. logging.info(self.message_hp + message_res)