clientapfl.py

import copy
import torch
import torch.nn as nn
import numpy as np
from flcore.clients.clientbase import Client


class clientAPFL(Client):
    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        super().__init__(args, id, train_samples, test_samples, **kwargs)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
        self.alpha = args.alpha

        # local model v with its own optimizer
        self.model_local = copy.deepcopy(self.model)
        self.optimizer_local = torch.optim.SGD(self.model_local.parameters(), lr=self.learning_rate)

        # personalized model v_bar with its own optimizer
        self.model_per = copy.deepcopy(self.model)
        self.optimizer_per = torch.optim.SGD(self.model_per.parameters(), lr=self.learning_rate)
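
    # APFL keeps three models per client:
    #   w     (self.model)       - the client's copy of the global model
    #   v     (self.model_local) - a purely local model that never leaves the client
    #   v_bar (self.model_per)   - the personalized mixture v_bar = alpha * v + (1 - alpha) * w
    # alpha in [0, 1] controls the degree of personalization; this client
    # treats it as a fixed hyperparameter (args.alpha).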
    def set_parameters(self, model):
        # Receive the new global model w, then refresh the personalized model
        # as the mixture v_bar = alpha * v + (1 - alpha) * w.
        for new_param, old_param, param_l, param_p in zip(
                model.parameters(), self.model.parameters(),
                self.model_local.parameters(), self.model_per.parameters()):
            old_param.data = new_param.data.clone()
            param_p.data = self.alpha * param_l.data + (1 - self.alpha) * new_param.data
    def train(self):
        trainloader = self.load_train_data()
        self.model.train()
        max_local_steps = self.local_steps

        # self.model_per: personalized model (v_bar); self.model: global model (w);
        # self.model_local: local model (v)
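        # Each batch applies the three APFL updates in order:
        #   w     <- w - eta * grad_w f(w; x, y)                    (global step)
        #   v     <- v - eta * alpha * grad_{v_bar} f(v_bar; x, y)  (local step)
        #   v_bar <- alpha * v + (1 - alpha) * w                    (interpolation)
        # where eta is self.learning_rate and f is the cross-entropy loss.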
        for step in range(max_local_steps):
            for i, (x, y) in enumerate(trainloader):
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)

                # update global model (self.model)
                self.optimizer.zero_grad()
                output = self.model(x)
                loss = self.criterion(output, y)
                loss.backward()
                self.optimizer.step()
                # compute grad_{v_bar} with a forward/backward pass on the
                # personalized model
                self.optimizer_per.zero_grad()
                output_per = self.model_per(x)
                loss_per = self.criterion(output_per, y)
                loss_per.backward()  # update model_local (by gradient) before updating model_per (by interpolation)

                # update model_local by gradient; since v_bar = alpha * v + (1 - alpha) * w,
                # the chain rule gives grad_v = alpha * grad_{v_bar}
                # see https://github.com/lgcollins/FedRep/blob/main/models/Update.py#L410 and the algorithm in the paper
                self.optimizer_local.zero_grad()
                for p_l, p_p in zip(self.model_local.parameters(), self.model_per.parameters()):
                    if p_l.grad is None:
                        p_l.grad = self.alpha * p_p.grad.data.clone()
                    else:
                        p_l.grad.data = self.alpha * p_p.grad.data.clone()
                self.optimizer_local.step()
                # update model_per by interpolation
                for p_p, p_g, p_l in zip(self.model_per.parameters(), self.model.parameters(),
                                         self.model_local.parameters()):
                    p_p.data = self.alpha * p_l.data + (1 - self.alpha) * p_g.data
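
For reference, here is a minimal, self-contained sketch of the two steps above that are easiest to get wrong: the alpha-scaled gradient copy into the local model, and the interpolation of the personalized model. The toy model, data, and alpha value below are illustrative assumptions, not part of this repository.

import copy
import torch
import torch.nn as nn

alpha = 0.25
w_model = nn.Linear(4, 2)           # stands in for the global model w
v_model = copy.deepcopy(w_model)    # local model v
per_model = copy.deepcopy(w_model)  # personalized model v_bar

x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()
opt_local = torch.optim.SGD(v_model.parameters(), lr=0.1)

# grad_{v_bar}: forward/backward pass through the personalized model
criterion(per_model(x), y).backward()

# local step: v's gradient is alpha * grad_{v_bar} (chain rule)
opt_local.zero_grad()
for p_l, p_p in zip(v_model.parameters(), per_model.parameters()):
    p_l.grad = alpha * p_p.grad.data.clone()
opt_local.step()

# interpolation: v_bar = alpha * v + (1 - alpha) * w
for p_p, p_l, p_g in zip(per_model.parameters(), v_model.parameters(), w_model.parameters()):
    p_p.data = alpha * p_l.data + (1 - alpha) * p_g.data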