fed_optimizer.py 1.2 KB

1234567891011121314151617181920212223242526272829
  1. from peft import (
  2. set_peft_model_state_dict,
  3. )
  4. import torch
  5. import os
  6. from torch.nn.functional import normalize
  7. def FedAvg(model, selected_clients_set, output_dir, local_dataset_len_dict, epoch):
  8. weights_array = normalize(
  9. torch.tensor([local_dataset_len_dict[client_id] for client_id in selected_clients_set],
  10. dtype=torch.float32),
  11. p=1, dim=0)
  12. for k, client_id in enumerate(selected_clients_set):
  13. single_output_dir = os.path.join(output_dir, str(epoch), "local_output_{}".format(client_id),
  14. "pytorch_model.bin")
  15. single_weights = torch.load(single_output_dir)
  16. if k == 0:
  17. weighted_single_weights = {key: single_weights[key] * (weights_array[k]) for key in
  18. single_weights.keys()}
  19. else:
  20. weighted_single_weights = {key: weighted_single_weights[key] + single_weights[key] * (weights_array[k])
  21. for key in
  22. single_weights.keys()}
  23. set_peft_model_state_dict(model, weighted_single_weights, "default")
  24. return model