model_aggregation.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435
  1. from peft import (
  2. set_peft_model_state_dict,
  3. )
  4. import torch
  5. import os
  6. from torch.nn.functional import normalize
  7. # 联邦平均算法
  8. def FedAvg(model, selected_clients_set, output_dir, local_dataset_len_dict, epoch):
  9. # 对各个客户端的本地数据集大小进行归一化,作为权重的基础
  10. # 这里将每个客户端的数据集大小转换为张量,并按照第0维进行 L1 归一化
  11. weights_array = normalize(
  12. torch.tensor([local_dataset_len_dict[client_id] for client_id in selected_clients_set],
  13. dtype=torch.float32),
  14. p=1, dim=0)
  15. # 遍历选定的客户端集合
  16. for k, client_id in enumerate(selected_clients_set):
  17. # 对每个选定的客户端,加载其训练得到的模型权重
  18. # 构造每个客户端权重的文件路径,并加载权重
  19. single_output_dir = os.path.join(output_dir, str(epoch), "local_output_{}".format(client_id),
  20. "pytorch_model.bin")
  21. single_weights = torch.load(single_output_dir)
  22. # 如果是第一个客户端,则将其权重乘以对应的归一化权重
  23. # 否则,将其权重乘以归一化权重后累加到总权重上
  24. if k == 0:
  25. weighted_single_weights = {key: single_weights[key] * (weights_array[k]) for key in
  26. single_weights.keys()}
  27. else:
  28. weighted_single_weights = {key: weighted_single_weights[key] + single_weights[key] * (weights_array[k])
  29. for key in
  30. single_weights.keys()}
  31. # 使用计算得到的加权平均权重,设置全局模型的状态
  32. set_peft_model_state_dict(model, weighted_single_weights, "default")
  33. return model