serverfomo.py

import torch
import copy
import random
import os
import logging
import numpy as np
from flcore.clients.clientfomo import clientFomo
from flcore.servers.serverbase import Server


class FedFomo(Server):
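    # FedFomo server: maintains a client-to-client affinity matrix P and, each
    # round, sends every client its top-M most useful peer models rather than a
    # single global model.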
    def __init__(self, args, times):
        super().__init__(args, times)

        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
        clientObj = clientFomo
        self.message_hp_dash = self.message_hp.replace(", ", "-")
        self.hist_result_fn = os.path.join(
            args.hist_dir,
            f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
        self.set_clients(args, clientObj)

        # P[i][j] accumulates how useful client j's model has been to client i;
        # it starts as the identity, so each client initially favours its own model.
        self.P = torch.diag(torch.ones(self.num_clients, device=self.device))
        self.uploaded_models = [self.global_model]
        self.uploaded_ids = []
        # Each client receives at most M models per round.
        self.M = min(args.M, self.join_clients)

        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
        print("Finished creating server and clients.")
    def train(self):
        for i in range(self.global_rounds):
            self.selected_clients = self.select_clients()
            self.send_models()

            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
            for client in self.selected_clients:
                client.train()

            self.receive_models()
            # self.aggregate_parameters()

            if i % self.eval_gap == 0:
                print("==> Evaluating personalized model")
                self.evaluate()
            if i == 80:
                self.check_early_stopping()
            print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)

        self.save_results(fn=self.hist_result_fn)
        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
        logging.info(self.message_hp + message_res)
        # self.save_global_model()
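
    # Dispatch step: for every selected client, pick the M_ uploaded models with
    # the highest scores in that client's row of P and hand them to the client.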
    def send_models(self):
        assert (len(self.selected_clients) > 0)
        for client in self.selected_clients:
            if len(self.uploaded_ids) > 0:
                M_ = min(self.M, len(self.uploaded_models))  # if clients dropped
                indices = torch.topk(self.P[client.id][self.uploaded_ids], M_).indices.tolist()
                uploaded_ids = []
                uploaded_models = []
                for i in indices:
                    uploaded_ids.append(self.uploaded_ids[i])
                    uploaded_models.append(self.uploaded_models[i])
                client.receive_models(uploaded_ids, uploaded_models)
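
    # Zeroes the global model, then accumulates every client's model weighted by
    # its number of training samples (via the base class's add_parameters).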
    def prepare_global_model(self):
        self.global_model = copy.deepcopy(self.clients[0].model)
        for p in self.global_model.parameters():
            p.data.zero_()
        for c in self.clients:
            self.add_parameters(c.train_samples, c.model)
        return
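
    # Collection step: gather models from the active clients, record their sample
    # counts as normalized aggregation weights, and fold each client's locally
    # computed weight_vector into the affinity matrix P.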
    def receive_models(self):
        assert (len(self.selected_clients) > 0)
        active_clients = random.sample(self.selected_clients, self.join_clients)

        self.uploaded_ids = []
        self.uploaded_weights = []
        tot_samples = 0
        self.uploaded_models = []
        for client in active_clients:
            self.uploaded_ids.append(client.id)
            self.uploaded_weights.append(client.train_samples)
            tot_samples += client.train_samples
            self.uploaded_models.append(copy.deepcopy(client.model))
            self.P[client.id] += client.weight_vector
        for i, w in enumerate(self.uploaded_weights):
            self.uploaded_weights[i] = w / tot_samples