clientpfedme.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import numpy as np
  2. import copy
  3. import torch
  4. import torch.nn as nn
  5. from flcore.optimizers.fedoptimizer import pFedMeOptimizer
  6. from flcore.clients.clientbase import Client
  class clientpFedMe(Client):
      """Federated client implementing pFedMe local training.

      Besides ``self.model`` it keeps two independent parameter lists:

      * ``local_params`` -- the client's local weights ``w``; after training
        they are written back into ``self.model``.
      * ``personalized_params`` -- the approximate personalized model
        ("theta") produced by the inner ``pFedMeOptimizer`` steps; used by
        ``get_eval_model``.
      """

      def __init__(self, args, id, train_samples, test_samples, **kwargs):
          super().__init__(args, id, train_samples, test_samples, **kwargs)
          # pFedMe regularization weight: passed to pFedMeOptimizer and also
          # scales the local-weight update in train().
          self.lambdaa = args.lambdaa
          # Number of inner (personalized) optimizer steps per mini-batch.
          self.K = args.K
          self.personalized_learning_rate = args.p_learning_rate

          # Parameters for personalized federated learning.  deepcopy gives
          # storage independent of self.model.
          self.local_params = copy.deepcopy(list(self.model.parameters()))
          self.personalized_params = copy.deepcopy(list(self.model.parameters()))
          self.criterion = nn.CrossEntropyLoss()

      def train(self):
          """Run pFedMe local training for ``self.local_steps`` passes.

          For each mini-batch, ``self.K`` steps of ``pFedMeOptimizer``
          compute an approximate personalized model theta.  The local weights
          ``w`` are then moved toward it:
          ``w <- w - lambdaa * learning_rate * (w - theta)``.
          When training finishes, ``self.model`` receives the local weights,
          and ``self.personalized_params`` holds a detached snapshot of theta.
          """
          trainloader = self.load_train_data()
          self.model.train()

          # self.local_params is handed to the optimizer (presumably as the
          # anchor for the inner pFedMe steps -- see pFedMeOptimizer), so the
          # update below must modify those same parameter objects in place.
          self.optimizer = pFedMeOptimizer(
              self.model.parameters(),
              local_model=self.local_params,
              lr=self.personalized_learning_rate,
              lambdaa=self.lambdaa,
          )

          for _ in range(self.local_steps):  # local update
              for x, y in trainloader:
                  if isinstance(x, list):
                      x[0] = x[0].to(self.device)
                  else:
                      x = x.to(self.device)
                  y = y.to(self.device)

                  # K is the number of personalized (inner) steps.
                  for _ in range(self.K):
                      self.optimizer.zero_grad()
                      output = self.model(x)
                      loss = self.criterion(output, y)
                      loss.backward()
                      # Find the approximate personalized model theta.
                      self.personalized_params = self.optimizer.step()

                  # Move the local weights toward theta.  Assign through
                  # ``.data`` on the stored parameter itself.  Rebinding the
                  # loop variable (``w = w.to(device)``) would update a
                  # temporary copy whenever a device transfer is needed,
                  # silently dropping the update from self.local_params.
                  for new_param, local_weight in zip(self.personalized_params, self.local_params):
                      local_weight.data = local_weight.data.to(self.device)
                      local_weight.data = local_weight.data - self.lambdaa * self.learning_rate * (
                          local_weight.data - new_param.data
                      )

          # NOTE(review): pFedMeOptimizer.step() may return live references
          # to self.model's parameters (e.g. its param_groups' params) --
          # confirm in fedoptimizer.  If so, the update_parameters call below
          # would turn personalized_params into a copy of local_params, and
          # get_eval_model() would no longer evaluate theta.  Snapshotting
          # first is correct either way.
          self.personalized_params = [p.detach().clone() for p in self.personalized_params]
          self.update_parameters(self.model, self.local_params)

      # Per the original author: comment this override out for testing on
      # new clients.
      def get_eval_model(self, temp_model=None):
          """Return ``self.model`` loaded with ``self.personalized_params``.

          Loading is done via ``update_parameters``.  ``temp_model`` is
          unused here.
          """
          self.update_parameters(self.model, self.personalized_params)
          return self.model

      def set_parameters(self, model):
          """Copy ``model``'s parameters into ``self.model`` and ``self.local_params``.

          Each destination receives its own clone, so they never share storage
          with ``model`` or with each other.
          """
          for new_param, old_param, local_param in zip(model.parameters(), self.model.parameters(), self.local_params):
              old_param.data = new_param.data.clone()
              local_param.data = new_param.data.clone()