serverperavg.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import copy
  2. import os
  3. import logging
  4. import torch
  5. from flcore.clients.clientperavg import clientPerAvg
  6. from flcore.servers.serverbase import Server
  7. class PerAvg(Server):
  8. def __init__(self, args, times):
  9. super().__init__(args, times)
  10. self.message_hp = f"{args.algorithm}, alpha:{args.alpha:.5f}, beta:{args.beta:.5f}"
  11. clientObj = clientPerAvg
  12. self.message_hp_dash = self.message_hp.replace(", ", "-")
  13. self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
  14. self.set_clients(args, clientObj)
  15. print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
  16. print("Finished creating server and clients.")
  17. def train(self):
  18. for i in range(self.global_rounds):
  19. self.selected_clients = self.select_clients()
  20. # send all parameter for clients
  21. self.send_models()
  22. print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
  23. print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
  24. for client in self.selected_clients:
  25. client.train()
  26. self.receive_models()
  27. self.aggregate_parameters()
  28. if i%self.eval_gap == 0:
  29. print("==> Evaluating personalized models...", flush=True)
  30. self.send_models(mode="all")
  31. self.evaluate(self.global_model)
  32. print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
  33. self.save_results(fn=self.hist_result_fn)
  34. message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
  35. logging.info(self.message_hp + message_res)