# client_with_pgfed.py
import copy
import gc
import logging
import time
from collections import Counter

import numpy as np
import torch
import torch._utils
import torch.nn as nn
import torch.nn.functional as F

import model
import utils
from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
from easyfl.client.base import BaseClient
from easyfl.distributed.distributed import CPU
from client import FedSSLClient

logger = logging.getLogger(__name__)

L2 = "l2"


def model_dot_product(w1, w2, requires_grad=True):
    """Return the dot product between the parameters of two models."""
    dot_product = 0.0
    for p1, p2 in zip(w1.parameters(), w2.parameters()):
        if requires_grad:
            dot_product += torch.sum(p1 * p2)
        else:
            dot_product += torch.sum(p1.data * p2.data)
    return dot_product
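
# Note: with requires_grad=False the product is computed on detached .data tensors,
# which is how the bookkeeping terms in this file use it, e.g.
#   model_dot_product(self.latest_grad, self.model, requires_grad=False)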


class FedSSLWithPgFedClient(FedSSLClient):
    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
        super(FedSSLWithPgFedClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
        self._local_model = None
        self.DAPU_predictor = LOCAL
        self.encoder_distance = 1
        self.encoder_distances = []
        self.previous_trained_round = -1
        self.weight_scaler = None

        # PGFed state
        self.latest_grad = copy.deepcopy(self.model)
        self.lambdaa = 1.0  # PGFed learning rate for a_i, regularization weight for pFedMe
        self.prev_loss_minuses = {}
        self.prev_mean_grad = None
        self.prev_convex_comb_grad = None
        self.a_i = None
        self.criterion = nn.CrossEntropyLoss()
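
    # Local training: one round of self-supervised training (MoCo/SimCLR/BYOL-style,
    # selected by conf.model) with the PGFed auxiliary-gradient correction applied
    # inside the batch loop, followed by the computation of loss_minus and latest_grad,
    # which are presumably collected by the server to build the next round's
    # prev_mean_grad / prev_convex_comb_grad / prev_loss_minuses.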
    def train(self, conf, device=CPU):
        start_time = time.time()
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        if conf.model in [model.MoCo, model.MoCoV2]:
            self.model.reset_key_encoder()
        self.train_loss = []
        self.model.to(device)
        old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
        for i in range(conf.local_epoch):
            batch_loss = []
            for (batched_x1, batched_x2), _ in self.train_loader:
                x1, x2 = batched_x1.to(device), batched_x2.to(device)
                optimizer.zero_grad()

                if conf.model in [model.MoCo, model.MoCoV2]:
                    loss = self.model(x1, x2, device)
                elif conf.model == model.SimCLR:
                    images = torch.cat((x1, x2), dim=0)
                    features = self.model(images)
                    logits, labels = self.info_nce_loss(features)
                    loss = loss_fn(logits, labels)
                else:
                    loss = self.model(x1, x2)

                loss.backward()
                # PGFed: add the stored auxiliary gradient to the freshly computed
                # local gradient, then update the personalized aggregation weights a_i
                # using the dot product of the current model with prev_mean_grad.
                if self.prev_convex_comb_grad is not None:
                    for p_m, p_prev_conv in zip(self.model.parameters(), self.prev_convex_comb_grad.parameters()):
                        p_m.grad.data += p_prev_conv.data
                    dot_prod = model_dot_product(self.model, self.prev_mean_grad, requires_grad=False)
                    self.update_a_i(dot_prod)
                optimizer.step()
                batch_loss.append(loss.item())

            if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
                self.model.update_moving_average()

            current_epoch_loss = sum(batch_loss) / len(batch_loss)
            self.train_loss.append(float(current_epoch_loss))

        # Compute loss_minus and latest_grad: accumulate the criterion loss and its
        # gradients over the local data, then subtract the first-order term
        # <latest_grad, model> from the averaged loss.
        self.loss_minus = 0.0
        test_num = 0
        optimizer.zero_grad()
        for i, (x, y) in enumerate(self.train_loader):
            # x can be a pair of augmented views (a list); only x[0] is moved to the device.
            if isinstance(x, list):
                x[0] = x[0].to(self.device)
            else:
                x = x.to(self.device)
            y = y.to(self.device)
            test_num += y.shape[0]
            output = self.model(x)
            loss = self.criterion(output, y)
            self.loss_minus += (loss * y.shape[0]).item()
            loss.backward()

        self.loss_minus /= test_num
        # Gradients were accumulated over the whole loader, so divide by the number
        # of batches to obtain the average gradient.
        for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
            p_l.data = p.grad.data.clone() / len(self.train_loader)
        self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)

        self.train_time = time.time() - start_time

        # store trained model locally
        self._local_model = copy.deepcopy(self.model).cpu()
        self.previous_trained_round = conf.round_id

        if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
            new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
            self.encoder_distance = self._calculate_divergence(old_model, new_model)
            self.encoder_distances.append(self.encoder_distance.item())
            self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
            if self.conf.auto_scaler == 'y' and self.conf.random_selection:
                self._calculate_weight_scaler()

        if (conf.round_id + 1) % 100 == 0:
            logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
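
    # PGFed: gradient-style update of the personalized aggregation weights a_i,
    # clipped at zero. For each peer j whose loss_minus was received:
    #   a_i[j] <- max(a_i[j] - lambdaa * (loss_minus_j + dot_prod), 0)
    # where dot_prod = <current model, prev_mean_grad> is computed in train().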
    def update_a_i(self, dot_prod):
        for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
            self.a_i[clt_j] -= self.lambdaa * (mu_loss_minus + dot_prod)
            self.a_i[clt_j] = max(self.a_i[clt_j], 0.0)
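
    # The setters below cache auxiliary gradients provided externally (e.g. by the
    # server): the first call deep-copies the incoming model, later calls refresh
    # the cached copy via set_model (with optional momentum for the
    # convex-combination gradient).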
    def set_prev_mean_grad(self, mean_grad):
        if self.prev_mean_grad is None:
            self.prev_mean_grad = copy.deepcopy(mean_grad)
        else:
            self.set_model(self.prev_mean_grad, mean_grad)

    def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
        if self.prev_convex_comb_grad is None:
            self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
        else:
            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
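

# Illustrative sanity check of model_dot_product on two toy models; the models
# here are hypothetical and not part of the training pipeline.
if __name__ == "__main__":
    _m1, _m2 = nn.Linear(4, 2), nn.Linear(4, 2)
    print(model_dot_product(_m1, _m2, requires_grad=False))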