strategies.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import copy
  2. import torch
  3. def federated_averaging(models, weights):
  4. """Compute weighted average of model parameters and persistent buffers.
  5. Using state_dict of model, including persistent buffers like BN stats.
  6. Args:
  7. models (list[nn.Module]): List of models to average.
  8. weights (list[float]): List of weights, corresponding to each model.
  9. Weights are dataset size of clients by default.
  10. Returns
  11. nn.Module: Weighted averaged model.
  12. """
  13. if models == [] or weights == []:
  14. return None
  15. model, total_weights = weighted_sum(models, weights)
  16. model_params = model.state_dict()
  17. with torch.no_grad():
  18. for name, params in model_params.items():
  19. model_params[name] = torch.div(params, total_weights)
  20. model.load_state_dict(model_params)
  21. return model
  22. def federated_averaging_only_params(models, weights):
  23. """Compute weighted average of model parameters. Use model parameters only.
  24. Args:
  25. models (list[nn.Module]): List of models to average.
  26. weights (list[float]): List of weights, corresponding to each model.
  27. Weights are dataset size of clients by default.
  28. Returns
  29. nn.Module: Weighted averaged model.
  30. """
  31. if models == [] or weights == []:
  32. return None
  33. model, total_weights = weighted_sum_only_params(models, weights)
  34. model_params = dict(model.named_parameters())
  35. with torch.no_grad():
  36. for name, params in model_params.items():
  37. model_params[name].set_(model_params[name] / total_weights)
  38. return model
  39. def weighted_sum(models, weights):
  40. """Compute weighted sum of model parameters and persistent buffers.
  41. Using state_dict of model, including persistent buffers like BN stats.
  42. Args:
  43. models (list[nn.Module]): List of models to average.
  44. weights (list[float]): List of weights, corresponding to each model.
  45. Weights are dataset size of clients by default.
  46. Returns
  47. nn.Module: Weighted averaged model.
  48. float: Sum of weights.
  49. """
  50. if models == [] or weights == []:
  51. return None
  52. model = copy.deepcopy(models[0])
  53. model_sum_params = copy.deepcopy(models[0].state_dict())
  54. with torch.no_grad():
  55. for name, params in model_sum_params.items():
  56. params *= weights[0]
  57. for i in range(1, len(models)):
  58. model_params = dict(models[i].state_dict())
  59. params += model_params[name] * weights[i]
  60. model_sum_params[name] = params
  61. model.load_state_dict(model_sum_params)
  62. return model, sum(weights)
  63. def weighted_sum_only_params(models, weights):
  64. """Compute weighted sum of model parameters. Use model parameters only.
  65. Args:
  66. models (list[nn.Module]): List of models to average.
  67. weights (list[float]): List of weights, corresponding to each model.
  68. Weights are dataset size of clients by default.
  69. Returns
  70. nn.Module: Weighted averaged model.
  71. float: Sum of weights.
  72. """
  73. if models == [] or weights == []:
  74. return None
  75. model_sum = copy.deepcopy(models[0])
  76. model_sum_params = dict(model_sum.named_parameters())
  77. with torch.no_grad():
  78. for name, params in model_sum_params.items():
  79. params *= weights[0]
  80. for i in range(1, len(models)):
  81. model_params = dict(models[i].named_parameters())
  82. params += model_params[name] * weights[i]
  83. model_sum_params[name].set_(params)
  84. return model_sum, sum(weights)
  85. def equal_weight_averaging(models):
  86. if models == []:
  87. return None
  88. model_avg = copy.deepcopy(models[0])
  89. model_avg_params = dict(model_avg.named_parameters())
  90. with torch.no_grad():
  91. for name, params in model_avg_params.items():
  92. for i in range(1, len(models)):
  93. model_params = dict(models[i].named_parameters())
  94. params += model_params[name]
  95. model_avg_params[name].set_(params / len(models))
  96. return model_avg