# clientperavg.py — Per-FedAvg (MAML-style personalized FedAvg) client.
import copy

import numpy as np
import torch
import torch.nn as nn
from sklearn import metrics
from sklearn.preprocessing import label_binarize

from flcore.clients.clientbase import Client
  8. class clientPerAvg(Client):
  9. def __init__(self, args, id, train_samples, test_samples, **kwargs):
  10. super().__init__(args, id, train_samples, test_samples, **kwargs)
  11. self.alpha = args.alpha
  12. self.beta = args.beta
  13. self.criterion = nn.CrossEntropyLoss()
  14. self.optimizer1 = torch.optim.SGD(self.model.parameters(), lr=self.alpha)
  15. self.optimizer2 = torch.optim.SGD(self.model.parameters(), lr=self.beta)
  16. def train(self):
  17. trainloader = self.load_train_data(self.batch_size*2)
  18. self.model.train()
  19. max_local_steps = self.local_steps
  20. for step in range(max_local_steps): # local update
  21. for X, Y in trainloader:
  22. temp_model = copy.deepcopy(list(self.model.parameters()))
  23. # step 1
  24. if type(X) == type([]):
  25. x = [None, None]
  26. x[0] = X[0][:self.batch_size].to(self.device)
  27. x[1] = X[1][:self.batch_size]
  28. else:
  29. x = X[:self.batch_size].to(self.device)
  30. y = Y[:self.batch_size].to(self.device)
  31. self.optimizer1.zero_grad()
  32. output = self.model(x)
  33. loss = self.criterion(output, y)
  34. loss.backward()
  35. self.optimizer1.step()
  36. # step 2
  37. if type(X) == type([]):
  38. x = [None, None]
  39. x[0] = X[0][self.batch_size:].to(self.device)
  40. x[1] = X[1][self.batch_size:]
  41. else:
  42. x = X[self.batch_size:].to(self.device)
  43. y = Y[self.batch_size:].to(self.device)
  44. self.optimizer2.zero_grad()
  45. output = self.model(x)
  46. loss = self.criterion(output, y)
  47. loss.backward()
  48. # restore the model parameters to the one before first update
  49. for old_param, new_param in zip(self.model.parameters(), temp_model):
  50. old_param.data = new_param.data.clone()
  51. self.optimizer2.step()
  52. # self.model.cpu()
  53. def train_one_step(self):
  54. trainloader = self.load_train_data(self.batch_size)
  55. iter_trainloader = iter(trainloader)
  56. self.model.train()
  57. (x, y) = next(iter_trainloader)
  58. if type(x) == type([]):
  59. x[0] = x[0].to(self.device)
  60. else:
  61. x = x.to(self.device)
  62. y = y.to(self.device)
  63. self.optimizer2.zero_grad()
  64. output = self.model(x)
  65. loss = self.criterion(output, y)
  66. loss.backward()
  67. self.optimizer2.step()
  68. # comment for testing on new clients
  69. def test_metrics(self, temp_model=None):
  70. temp_model = copy.deepcopy(self.model)
  71. self.train_one_step()
  72. return_val = super().test_metrics(temp_model)
  73. self.clone_model(temp_model, self.model)
  74. return return_val