Browse Source

update utils

JayZhang42 1 year ago
parent
commit
90d8aa6128

+ 2 - 2
fed_utils/__init__.py

@@ -1,5 +1,5 @@
-from .fed_optimizer import FedAvg
-from .client_selection import client_selection
+from .model_aggregation import FedAvg
+from .client_participation_scheduling import client_selection
 from .client import GeneralClient
 from .evaluation import global_evaluation
 from .other import other_function

+ 10 - 0
fed_utils/client_participation_scheduling.py

@@ -0,0 +1,10 @@
+import numpy as np
+
+
+def client_selection(num_clients, client_selection_frac, client_selection_strategy, other_info=None):
+    np.random.seed(other_info)
+    if client_selection_strategy == "random":
+        num_selected = max(int(client_selection_frac * num_clients), 1)
+        selected_clients_set = set(np.random.choice(np.arange(num_clients), num_selected, replace=False))
+
+    return selected_clients_set

+ 29 - 0
fed_utils/model_aggregation.py

@@ -0,0 +1,29 @@
+from peft import (
+    set_peft_model_state_dict,
+)
+import torch
+import os
+from torch.nn.functional import normalize
+
+
+def FedAvg(model, selected_clients_set, output_dir, local_dataset_len_dict, epoch):
+    weights_array = normalize(
+        torch.tensor([local_dataset_len_dict[client_id] for client_id in selected_clients_set],
+                     dtype=torch.float32),
+        p=1, dim=0)
+
+    for k, client_id in enumerate(selected_clients_set):
+        single_output_dir = os.path.join(output_dir, str(epoch), "local_output_{}".format(client_id),
+                                         "pytorch_model.bin")
+        single_weights = torch.load(single_output_dir)
+        if k == 0:
+            weighted_single_weights = {key: single_weights[key] * (weights_array[k]) for key in
+                                       single_weights.keys()}
+        else:
+            weighted_single_weights = {key: weighted_single_weights[key] + single_weights[key] * (weights_array[k])
+                                       for key in
+                                       single_weights.keys()}
+
+    set_peft_model_state_dict(model, weighted_single_weights, "default")
+
+    return model