serveravg.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import time
  2. from flcore.clients.clientavg import clientAVG
  3. from flcore.servers.serverbase import Server
  4. import os
  5. import logging
  6. import torch
  7. class FedAvg(Server):
  8. def __init__(self, args, times):
  9. super().__init__(args, times)
  10. self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
  11. clientObj = clientAVG
  12. self.message_hp_dash = self.message_hp.replace(", ", "-")
  13. self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
  14. self.last_ckpt_fn = os.path.join(self.ckpt_dir, f"FedAvg-cifar10-100clt.pt")
  15. self.set_clients(args, clientObj)
  16. print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
  17. print("Finished creating server and clients.")
  18. self.Budget = []
  19. def train(self):
  20. for i in range(self.global_rounds):
  21. self.selected_clients = self.select_clients()
  22. self.send_models()
  23. print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
  24. print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
  25. for client in self.selected_clients:
  26. client.train()
  27. self.receive_models()
  28. self.aggregate_parameters()
  29. if i%self.eval_gap == 0:
  30. print("==> Evaluating global models...", flush=True)
  31. self.send_models(mode="all")
  32. # self.evaluate(mode="global")
  33. self.evaluate()
  34. if i == 80:
  35. self.check_early_stopping()
  36. print(f"==> Best mean accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
  37. self.save_results(fn=self.hist_result_fn)
  38. message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
  39. logging.info(self.message_hp + message_res)
  40. # state = {
  41. # "global_model": self.global_model.cpu().state_dict(),
  42. # "clients_test_accs": self.clients_test_accs[-1]
  43. # }
  44. # self.save_global_model(model_path=self.last_ckpt_fn, state=state)