Explorar o código

Delete fed_optimizer.py

JayZhang42 hai 1 ano
pai
achega
a92cfb2cd9
Modificáronse 1 ficheiros con 0 adicións e 29 borrados
  1. 0 29
      fed_utils/fed_optimizer.py

+ 0 - 29
fed_utils/fed_optimizer.py

@@ -1,29 +0,0 @@
-from peft import (
-    set_peft_model_state_dict,
-)
-import torch
-import os
-from torch.nn.functional import normalize
-
-
-def FedAvg(model, selected_clients_set, output_dir, local_dataset_len_dict, epoch):
-    weights_array = normalize(
-        torch.tensor([local_dataset_len_dict[client_id] for client_id in selected_clients_set],
-                     dtype=torch.float32),
-        p=1, dim=0)
-
-    for k, client_id in enumerate(selected_clients_set):
-        single_output_dir = os.path.join(output_dir, str(epoch), "local_output_{}".format(client_id),
-                                         "pytorch_model.bin")
-        single_weights = torch.load(single_output_dir)
-        if k == 0:
-            weighted_single_weights = {key: single_weights[key] * (weights_array[k]) for key in
-                                       single_weights.keys()}
-        else:
-            weighted_single_weights = {key: weighted_single_weights[key] + single_weights[key] * (weights_array[k])
-                                       for key in
-                                       single_weights.keys()}
-
-    set_peft_model_state_dict(model, weighted_single_weights, "default")
-
-    return model