# clientdyn.py — FedDyn client (local training with dynamic regularization)
  1. import copy
  2. import torch
  3. import torch.nn as nn
  4. import numpy as np
  5. from flcore.clients.clientbase import Client
  6. from utils.tensor_utils import l2_squared_diff, model_dot_product
  7. class clientDyn(Client):
  8. def __init__(self, args, id, train_samples, test_samples, **kwargs):
  9. super().__init__(args, id, train_samples, test_samples, **kwargs)
  10. self.criterion = nn.CrossEntropyLoss()
  11. self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
  12. self.alpha = args.alpha
  13. self.global_model_vector = None
  14. self.old_grad = copy.deepcopy(self.model)
  15. for p in self.old_grad.parameters():
  16. p.requires_grad = False
  17. p.data.zero_()
  18. def train(self):
  19. trainloader = self.load_train_data()
  20. # self.model.to(self.device)
  21. self.model.train()
  22. max_local_steps = self.local_steps
  23. for step in range(max_local_steps):
  24. for i, (x, y) in enumerate(trainloader):
  25. if type(x) == type([]):
  26. x[0] = x[0].to(self.device)
  27. else:
  28. x = x.to(self.device)
  29. y = y.to(self.device)
  30. self.optimizer.zero_grad()
  31. output = self.model(x)
  32. loss = self.criterion(output, y)
  33. if self.untrained_global_model != None:
  34. loss += self.alpha/2 * l2_squared_diff(self.model, self.untrained_global_model)
  35. loss -= model_dot_product(self.model, self.old_grad)
  36. loss.backward()
  37. self.optimizer.step()
  38. if self.untrained_global_model != None:
  39. for p_old_grad, p_cur, p_broadcast in zip(self.old_grad.parameters(), self.model.parameters(), self.untrained_global_model.parameters()):
  40. p_old_grad.data -= self.alpha * (p_cur.data - p_broadcast.data)
  41. # self.model.cpu()
  42. def set_parameters(self, model):
  43. for new_param, old_param in zip(model.parameters(), self.model.parameters()):
  44. old_param.data = new_param.data.clone()
  45. self.untrained_global_model = copy.deepcopy(model)
  46. for p in self.untrained_global_model.parameters():
  47. p.requires_grad = False