client_with_pgfed.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import copy
  2. import gc
  3. import logging
  4. import time
  5. from collections import Counter
  6. import numpy as np
  7. import torch
  8. import torch._utils
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. import model
  12. import utils
  13. from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
  14. from easyfl.client.base import BaseClient
  15. from easyfl.distributed.distributed import CPU
  16. from sklearn.neighbors import KernelDensity
  17. from client import FedSSLClient
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# String constant "l2". It is not referenced by the code visible in this
# module; presumably a norm/distance selector used elsewhere (TODO confirm).
L2 = "l2"
  20. def compute_all_batch_privacy(model, data_loader, device):
  21. model.eval() # 切换到评估模式
  22. batch_privacy_scores = []
  23. for batch, _ in data_loader:
  24. batch = batch.to(device)
  25. with torch.no_grad():
  26. output_features = model(batch).cpu().numpy()
  27. last_weight = list(model.parameters())[-1].detach().cpu().numpy().reshape(-1, 1)
  28. kde_output = KernelDensity(kernel='gaussian').fit(output_features)
  29. kde_weight = KernelDensity(kernel='gaussian').fit(last_weight)
  30. combined_features = np.hstack((output_features, last_weight[:output_features.shape[0], :])) # 确保形状匹配
  31. kde_combined = KernelDensity(kernel='gaussian').fit(combined_features)
  32. log_p_x = kde_output.score_samples(output_features)
  33. log_p_y = kde_weight.score_samples(last_weight[:output_features.shape[0], :])
  34. log_p_xy = kde_combined.score_samples(combined_features)
  35. privacy = np.mean(log_p_xy - log_p_x - log_p_y)
  36. batch_privacy_scores.append(privacy)
  37. # 根据互信息分数进行排序
  38. sorted_batches = np.argsort(batch_privacy_scores)
  39. return sorted_batches, batch_privacy_scores
  40. def compute_batch_privacy(model, batch, device):
  41. model.eval()
  42. batch = batch.to(device)
  43. with torch.no_grad():
  44. output_features = model(batch).cpu().numpy()
  45. last_weight = list(model.parameters())[-1].detach().cpu().numpy().reshape(-1, 1)
  46. kde_output = KernelDensity(kernel='gaussian').fit(output_features)
  47. kde_weight = KernelDensity(kernel='gaussian').fit(last_weight)
  48. combined_features = np.hstack((output_features, last_weight[:output_features.shape[0], :]))
  49. kde_combined = KernelDensity(kernel='gaussian').fit(combined_features)
  50. log_p_x = kde_output.score_samples(output_features)
  51. log_p_y = kde_weight.score_samples(last_weight[:output_features.shape[0], :])
  52. log_p_xy = kde_combined.score_samples(combined_features)
  53. privacy = np.mean(log_p_xy - log_p_x - log_p_y)
  54. return privacy
  55. def model_dot_product(w1, w2, requires_grad=True):
  56. """ Return the sum of squared difference between two models. """
  57. dot_product = 0.0
  58. for p1, p2 in zip(w1.parameters(), w2.parameters()):
  59. if requires_grad:
  60. dot_product += torch.sum(p1 * p2)
  61. else:
  62. dot_product += torch.sum(p1.data * p2.data)
  63. return dot_product
  64. class FedSSLWithPgFedClient(FedSSLClient):
  65. def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
  66. super(FedSSLWithPgFedClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
  67. self._local_model = None
  68. self.DAPU_predictor = LOCAL
  69. self.encoder_distance = 1
  70. self.encoder_distances = []
  71. self.previous_trained_round = -1
  72. self.weight_scaler = None
  73. self.latest_grad = None
  74. self.lambdaa = 1.0 # PGFed learning rate for a_i, Regularization weight for pFedMe
  75. self.prev_loss_minuses = {}
  76. self.prev_mean_grad = None
  77. self.prev_convex_comb_grad = None
  78. self.a_i = None
  79. def train(self, conf, device=CPU):
  80. start_time = time.time()
  81. loss_fn, optimizer = self.pretrain_setup(conf, device)
  82. if conf.model in [model.MoCo, model.MoCoV2]:
  83. self.model.reset_key_encoder()
  84. self.train_loss = []
  85. self.model.to(device)
  86. batch_privacy_scores = []
  87. old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
  88. for i in range(conf.local_epoch):
  89. batch_loss = []
  90. for (batched_x1, batched_x2), _ in self.train_loader:
  91. x1, x2 = batched_x1.to(device), batched_x2.to(device)
  92. optimizer.zero_grad()
  93. privacy_score = compute_batch_privacy(self.model, x1, device)
  94. batch_privacy_scores.append(privacy_score)
  95. if conf.model in [model.MoCo, model.MoCoV2]:
  96. loss = self.model(x1, x2, device)
  97. elif conf.model == model.SimCLR:
  98. images = torch.cat((x1, x2), dim=0)
  99. features = self.model(images)
  100. logits, labels = self.info_nce_loss(features)
  101. loss = loss_fn(logits, labels)
  102. else:
  103. loss = self.model(x1, x2)
  104. loss.backward()
  105. if self.prev_convex_comb_grad is not None:
  106. for p_m, p_prev_conv in zip(self.model.parameters(), self.prev_convex_comb_grad.parameters()):
  107. p_m.grad.data += p_prev_conv.data
  108. dot_prod = model_dot_product(self.model, self.prev_mean_grad, requires_grad=False)
  109. self.update_a_i(dot_prod)
  110. optimizer.step()
  111. batch_loss.append(loss.item())
  112. if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
  113. self.model.update_moving_average()
  114. current_epoch_loss = sum(batch_loss) / len(batch_loss)
  115. self.train_loss.append(float(current_epoch_loss))
  116. print(f"Batch privacy scores during training: {batch_privacy_scores}")
  117. print(f"在第 {i+1} 轮训练结束时, a_i 的值为: {self.a_i}")
  118. print(f"Sum of batch privacy scores during training: {sum(batch_privacy_scores)}")
  119. self.loss_minus = 0.0
  120. test_num = 0
  121. optimizer.zero_grad()
  122. for (batched_x1, batched_x2), _ in self.train_loader:
  123. x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
  124. test_num += x1.size(0)
  125. if conf.model in [model.MoCo, model.MoCoV2]:
  126. loss = self.model(x1, x2, device)
  127. elif conf.model == model.SimCLR:
  128. images = torch.cat((x1, x2), dim=0)
  129. features = self.model(images)
  130. logits, labels = self.info_nce_loss(features)
  131. loss = loss_fn(logits, labels)
  132. else:
  133. loss = self.model(x1, x2)
  134. self.loss_minus += loss.item() * x1.size(0)
  135. self.loss_minus /= test_num
  136. if not self.latest_grad:
  137. self.latest_grad = copy.deepcopy(self.model)
  138. # delete later: for test
  139. # all_grads_none = True
  140. # for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
  141. # if p.grad is not None:
  142. # p_l.data = p.grad.data.clone() / len(self.train_loader)
  143. # all_grads_none = False
  144. # else:
  145. # p_l.data = torch.zeros_like(p_l.data)
  146. # if all_grads_none:
  147. # print("All None")
  148. self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
  149. self.train_time = time.time() - start_time
  150. # store trained model locally
  151. # self._local_model = copy.deepcopy(self.model).cpu()
  152. # self.previous_trained_round = conf.round_id
  153. # if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
  154. # new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
  155. # self.encoder_distance = self._calculate_divergence(old_model, new_model)
  156. # self.encoder_distances.append(self.encoder_distance.item())
  157. # self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
  158. # if self.conf.auto_scaler == 'y' and self.conf.random_selection:
  159. # self._calculate_weight_scaler()
  160. # if (conf.round_id + 1) % 100 == 0:
  161. # logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
  162. def update_a_i(self, dot_prod):
  163. for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
  164. self.a_i[clt_j] -= self.lambdaa * (mu_loss_minus + dot_prod)
  165. self.a_i[clt_j] = max(self.a_i[clt_j], 0.0)
  166. def set_prev_mean_grad(self, mean_grad):
  167. if self.prev_mean_grad is None:
  168. print("Initing prev_mean_grad")
  169. self.prev_mean_grad = copy.deepcopy(mean_grad)
  170. else:
  171. print("Setting prev_mean_grad")
  172. self.set_model(self.prev_mean_grad, mean_grad)
  173. def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
  174. if self.prev_convex_comb_grad is None:
  175. print("Initing prev_convex_comb_grad")
  176. self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
  177. else:
  178. print("Setting prev_convex_comb_grad")
  179. self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
  180. def set_model(self, old_m, new_m, momentum=0.0):
  181. for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
  182. p_old.data = (1 - momentum) * p_new.data.clone() + momentum * p_old.data.clone()