import copy
import logging
import os

import h5py

from flcore.clients.clientpfedme import clientpFedMe
from flcore.servers.serverbase import Server


class pFedMe(Server):
    """Federated server driving the pFedMe algorithm.

    Clients are created as ``clientpFedMe`` instances. Each round, after the
    ``Server`` base class has received and aggregated client models, the new
    global model is blended with the previous round's global model using the
    mixing weight ``args.beta`` (see ``beta_aggregate_parameters``).
    """

    def __init__(self, args, times):
        super().__init__(args, times)

        # Hyper-parameter tag used in log lines; the dashed variant is file-name safe.
        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
        self.message_hp_dash = self.message_hp.replace(", ", "-")

        # History file (.h5) that save_results writes to at the end of training.
        hist_name = f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5"
        self.hist_result_fn = os.path.join(args.hist_dir, hist_name)

        self.set_clients(args, clientpFedMe)

        # Weight given to the freshly aggregated model when blending with the old one.
        self.beta = args.beta

        # Personalized-metric histories; not touched elsewhere in this class,
        # presumably filled by base-class evaluation code -- TODO confirm.
        self.rs_train_acc_per, self.rs_train_loss_per, self.rs_test_acc_per = [], [], []

        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
        print("Finished creating server and clients.")

    def train(self):
        """Run ``self.global_rounds`` rounds of pFedMe training.

        Per round: pick clients and send them the global model, train them
        locally, evaluate personalized models every ``self.eval_gap`` rounds,
        then aggregate client models and apply the beta blending step.
        Once all rounds finish, results are saved to ``self.hist_result_fn``
        and the best mean test accuracy is logged.
        """
        for round_idx in range(self.global_rounds):
            self.selected_clients = self.select_clients()
            self.send_models()

            header = f"[{round_idx + 1:3d}/{self.global_rounds}]"
            print(f"\n------------- Round number: {header}-------------")
            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
            for participant in self.selected_clients:
                participant.train()

            if round_idx % self.eval_gap == 0:
                print("==> Evaluating personalized models...")
                self.evaluate()

            # Snapshot current global weights before receive/aggregate update the
            # global model; beta_aggregate_parameters blends against this copy.
            self.previous_global_model = copy.deepcopy(list(self.global_model.parameters()))
            self.receive_models()
            self.aggregate_parameters()
            self.beta_aggregate_parameters()

        self.save_results(fn=self.hist_result_fn)
        logging.info(f"{self.message_hp}\ttest_acc:{self.best_mean_test_acc:.6f}")
        # NOTE: saving the global model is deliberately disabled here.

    def beta_aggregate_parameters(self):
        """Blend the aggregated global model with the pre-aggregation snapshot.

        Every parameter becomes ``(1 - beta) * previous + beta * aggregated``.
        ``self.previous_global_model`` must already be set (``train`` does this
        right before aggregation).
        """
        keep = 1 - self.beta
        snapshot = self.previous_global_model
        for old, cur in zip(snapshot, self.global_model.parameters()):
            cur.data = keep * old.data + self.beta * cur.data