clientpgfed.py

import torch
import numpy as np
import copy
import torch.nn as nn
from flcore.clients.clientbase import Client
from utils.tensor_utils import model_dot_product
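
# Client-side logic for PGFed: each client keeps nonnegative aggregation
# weights a_i over the other clients and augments its local empirical risk
# with a first-order (gradient plus loss_minus) approximation of their risks,
# built from the server-aggregated gradients held in the fields below.
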
class clientPGFed(Client):
    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        super().__init__(args, id, train_samples, test_samples, **kwargs)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
        self.lambdaa = args.lambdaa  # \eta_2 in the paper: the learning rate for a_i
        self.latest_grad = copy.deepcopy(self.model)  # buffer for this client's averaged full-pass gradient
        self.prev_loss_minuses = {}  # client index -> weighted loss_minus received from the server
        self.prev_mean_grad = None  # mean of the participating clients' latest gradients
        self.prev_convex_comb_grad = None  # a_i-weighted combination of the clients' gradients
        self.a_i = None  # personalized aggregation weights over the other clients

    def train(self):
        trainloader = self.load_train_data()
        self.model.train()
        max_local_steps = self.local_steps
        # first pass: local SGD steps on the personalized objective
        for step in range(max_local_steps):
            for i, (x, y) in enumerate(trainloader):
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(x)
                loss = self.criterion(output, y)
                loss.backward()
                if self.prev_convex_comb_grad is not None:
                    # add the gradient of the first-order auxiliary risks, then
                    # update the aggregation weights a_i
                    for p_m, p_prev_conv in zip(self.model.parameters(), self.prev_convex_comb_grad.parameters()):
                        p_m.grad.data += p_prev_conv.data
                    dot_prod = model_dot_product(self.model, self.prev_mean_grad, requires_grad=False)
                    self.update_a_i(dot_prod)
                self.optimizer.step()

        # second pass: compute loss_minus and latest_grad to report to the server
        self.loss_minus = 0.0
        test_num = 0
        self.optimizer.zero_grad()
        for i, (x, y) in enumerate(trainloader):
            if isinstance(x, list):
                x[0] = x[0].to(self.device)
            else:
                x = x.to(self.device)
            y = y.to(self.device)
            test_num += y.shape[0]
            output = self.model(x)
            loss = self.criterion(output, y)
            self.loss_minus += (loss * y.shape[0]).item()
            loss.backward()  # gradients accumulate across batches
        self.loss_minus /= test_num
        for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
            p_l.data = p.grad.data.clone() / len(trainloader)  # average the accumulated gradient
        self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
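        # loss_minus is the constant term of the first-order expansion other
        # clients use as a stand-in for this client's risk:
        #   f_i(w) ~ f_i(w_i) + <grad f_i(w_i), w - w_i>
        #          = <grad f_i(w_i), w> + (f_i(w_i) - <grad f_i(w_i), w_i>),
        # so loss_minus = f_i(w_i) - <grad f_i(w_i), w_i>.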

    def get_eval_model(self, temp_model=None):
        model = self.model if temp_model is None else temp_model
        return model

    def update_a_i(self, dot_prod):
        for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
            self.a_i[clt_j] -= self.lambdaa * (mu_loss_minus + dot_prod)
            self.a_i[clt_j] = max(self.a_i[clt_j], 0.0)
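
    # update_a_i is a projected gradient step on the aggregation weights:
    #   a_i[j] <- max(0, a_i[j] - lambdaa * (mu_loss_minus_j + <w_i, mean_grad>)),
    # where dot_prod = <w_i, prev_mean_grad> and the clamp keeps a_i nonnegative.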

    def set_model(self, old_m, new_m, momentum=0.0):
        # in-place blend: old_m <- (1 - momentum) * new_m + momentum * old_m
        for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
            p_old.data = (1 - momentum) * p_new.data.clone() + momentum * p_old.data.clone()

    def set_prev_mean_grad(self, mean_grad):
        if self.prev_mean_grad is None:
            self.prev_mean_grad = copy.deepcopy(mean_grad)
        else:
            self.set_model(self.prev_mean_grad, mean_grad)

    def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
        if self.prev_convex_comb_grad is None:
            self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
        else:
            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
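
# A minimal sketch of the server-side round these hooks assume; `selected`,
# `model_average`, and `combine` are illustrative names, not this repo's API:
#
#   for c in selected:
#       c.train()  # local steps, then loss_minus / latest_grad
#   mean_grad = model_average([c.latest_grad for c in selected])
#   for c in selected:
#       c.set_prev_mean_grad(mean_grad)
#       c.set_prev_convex_comb_grad(combine([...], c.a_i), momentum=momentum)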