clientrep.py

import torch
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn import metrics
import copy
from flcore.clients.clientbase import Client


class clientRep(Client):
    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        super().__init__(args, id, train_samples, test_samples, **kwargs)

        self.criterion = nn.CrossEntropyLoss()
        # Separate optimizers for the shared representation (base) and the
        # personalized head (predictor).
        self.optimizer = torch.optim.SGD(self.model.base.parameters(), lr=self.learning_rate)
        self.poptimizer = torch.optim.SGD(self.model.predictor.parameters(), lr=self.learning_rate)

        self.plocal_steps = args.plocal_steps

    def train(self):
        trainloader = self.load_train_data()
        # self.model.to(self.device)
        self.model.train()

        # Phase 1: train only the personalized predictor (head) while the
        # shared base is frozen.
        for param in self.model.base.parameters():
            param.requires_grad = False
        for param in self.model.predictor.parameters():
            param.requires_grad = True

        for step in range(self.plocal_steps):
            for i, (x, y) in enumerate(trainloader):
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                self.poptimizer.zero_grad()
                output = self.model(x)
                loss = self.criterion(output, y)
                loss.backward()
                self.poptimizer.step()

        max_local_steps = self.local_steps

        # Phase 2: train only the shared base while the predictor is frozen.
        for param in self.model.base.parameters():
            param.requires_grad = True
        for param in self.model.predictor.parameters():
            param.requires_grad = False

        for step in range(max_local_steps):
            for i, (x, y) in enumerate(trainloader):
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(x)
                loss = self.criterion(output, y)
                loss.backward()
                self.optimizer.step()

        # self.model.cpu()

    def set_parameters(self, base):
        # Copy the global base (representation) parameters into the local model;
        # the personalized predictor is left untouched.
        for new_param, old_param in zip(base.parameters(), self.model.base.parameters()):
            old_param.data = new_param.data.clone()
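
For context, train() above follows a FedRep-style split: each round the client first fits its personalized head (predictor) with the shared representation (base) frozen, then fits the base with the head frozen, and set_parameters() only overwrites the base with the global copy. The standalone sketch below reproduces that freeze/unfreeze pattern on a toy model; SplitNet, its dimensions, and the synthetic batch are illustrative assumptions and not part of the repository.

import torch
import torch.nn as nn

# Toy model split into a shared "base" (representation) and a personalized
# "predictor" (head), mirroring the base/predictor attributes clientRep expects.
class SplitNet(nn.Module):
    def __init__(self, in_dim=10, hidden=16, num_classes=3):
        super().__init__()
        self.base = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.predictor = nn.Linear(hidden, num_classes)

    def forward(self, x):
        return self.predictor(self.base(x))

model = SplitNet()
criterion = nn.CrossEntropyLoss()
head_opt = torch.optim.SGD(model.predictor.parameters(), lr=0.01)
base_opt = torch.optim.SGD(model.base.parameters(), lr=0.01)

# Synthetic batch standing in for one minibatch from the client's trainloader.
x = torch.randn(32, 10)
y = torch.randint(0, 3, (32,))

# Phase 1: update only the head, base frozen.
for p in model.base.parameters():
    p.requires_grad = False
for p in model.predictor.parameters():
    p.requires_grad = True
head_opt.zero_grad()
loss = criterion(model(x), y)
loss.backward()
head_opt.step()

# Phase 2: update only the base, head frozen.
for p in model.base.parameters():
    p.requires_grad = True
for p in model.predictor.parameters():
    p.requires_grad = False
base_opt.zero_grad()
loss = criterion(model(x), y)
loss.backward()
base_opt.step()

In the full system, only model.base would be uploaded for server-side averaging, while the predictor stays local to preserve personalization.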