fedoptimizer.py 849 B

123456789101112131415161718192021
  1. import random
  2. import torch
  3. from torch.optim import Optimizer
  4. class pFedMeOptimizer(Optimizer):
  5. def __init__(self, params, local_model=None, lr=0.01, lambdaa=0.1, mu=0.001):
  6. if lr < 0.0:
  7. raise ValueError("Invalid learning rate: {}".format(lr))
  8. defaults = dict(lr=lr, lambdaa=lambdaa, mu=mu)
  9. super(pFedMeOptimizer, self).__init__(params, defaults)
  10. self.weight_update = local_model.copy()
  11. def step(self):
  12. group = None
  13. for group in self.param_groups:
  14. for p, localweight in zip(group['params'], self.weight_update):
  15. localweight = localweight.to(p)
  16. # approximate local model
  17. p.data = p.data - group['lr'] * (p.grad.data + group['lambdaa'] * (p.data - localweight.data) + group['mu'] * p.data)
  18. return group['params']