client.py

import copy
import gc
import logging
import time
from collections import Counter

import numpy as np
import torch
import torch._utils
import torch.nn as nn
import torch.nn.functional as F

import model
import utils
from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
from easyfl.client.base import BaseClient
from easyfl.distributed.distributed import CPU
logger = logging.getLogger(__name__)  # module-level logger

L2 = "l2"


class FedSSLClient(BaseClient):
    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
        super(FedSSLClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
        self._local_model = None
        self.DAPU_predictor = LOCAL
        self.encoder_distance = 1
        self.encoder_distances = []
        self.previous_trained_round = -1
        self.weight_scaler = None
        self.batch_id_privacy = []
        self.batch_id = []

    def decompression(self):
        if self.model is None:
            # Initialization at beginning of the task
            self.model = self.compressed_model
        self.update_model()

    def update_model(self):
        if self.conf.model in [model.MoCo, model.MoCoV2]:
            self.model.encoder_q = self.compressed_model.encoder_q
            # self.model.encoder_k = copy.deepcopy(self._local_model.encoder_k)
        elif self.conf.model == model.SimCLR:
            self.model.online_encoder = self.compressed_model.online_encoder
        elif self.conf.model in [model.SimSiam, model.SimSiamNoSG]:
            if self._local_model is None:
                self.model.online_encoder = self.compressed_model.online_encoder
                self.model.online_predictor = self.compressed_model.online_predictor
                return
            if self.conf.update_encoder == ONLINE:
                online_encoder = self.compressed_model.online_encoder
            else:
                raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
                                 f"update {self.conf.update_encoder} is not supported")
            if self.conf.update_predictor == GLOBAL:
                predictor = self.compressed_model.online_predictor
            else:
                raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")
            self.model.online_encoder = copy.deepcopy(online_encoder)
            self.model.online_predictor = copy.deepcopy(predictor)
        elif self.conf.model in [model.Symmetric, model.SymmetricNoSG]:
            self.model.online_encoder = self.compressed_model.online_encoder
        elif self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
            if self._local_model is None:
                logger.info("Use aggregated encoder and predictor")
                self.model.online_encoder = self.compressed_model.online_encoder
                self.model.target_encoder = self.compressed_model.online_encoder
                self.model.online_predictor = self.compressed_model.online_predictor
                return
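
            # FedEMA update of the online encoder. With mu = min(1, scaler * divergence),
            # the result is online = (1 - mu) * global + mu * local, assuming
            # model.update_moving_average(updater, old, new) keeps a `weight`
            # fraction of `old` and mixes in (1 - weight) of `new`: the further a
            # client has drifted, the more of its local encoder it keeps.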
            def ema_online():
                self._calculate_weight_scaler()
                logger.info(f"Encoder: update online with EMA of global encoder @ round {self.conf.round_id}")
                weight = self.encoder_distance
                weight = min(1, self.weight_scaler * weight)
                weight = 1 - weight
                self.compressed_model = self.compressed_model.cpu()
                online_encoder = self.compressed_model.online_encoder  # aggregated online encoder
                target_encoder = self._local_model.target_encoder
                ema_updater = model.EMA(weight)
                model.update_moving_average(ema_updater, online_encoder, self._local_model.online_encoder)
                return online_encoder, target_encoder
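
            # Divergence-aware predictor interpolation. Under the same assumption
            # about model.update_moving_average, both branches compute the same
            # convex combination, predictor = d * local + (1 - d) * global with
            # d the clipped scaled divergence; the branch only decides which
            # module object is mutated in place and returned.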
            def ema_predictor():
                logger.info("Predictor: use dynamic DAPU")
                distance = self.encoder_distance
                distance = min(1, distance * self.weight_scaler)
                if distance > 0.5:
                    weight = distance
                    ema_updater = model.EMA(weight)
                    predictor = self._local_model.online_predictor
                    model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
                else:
                    weight = 1 - distance
                    ema_updater = model.EMA(weight)
                    predictor = self.compressed_model.online_predictor
                    model.update_moving_average(ema_updater, predictor, self._local_model.online_predictor)
                return predictor
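
            # Select the encoders for the next round from the
            # (aggregate_encoder, update_encoder) configuration pair.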
            if self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == ONLINE:
                logger.info("Encoder: aggregate online, update online")
                online_encoder = self.compressed_model.online_encoder
                target_encoder = self._local_model.target_encoder
            elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == ONLINE:
                logger.info("Encoder: aggregate target, update online")
                online_encoder = self.compressed_model.target_encoder
                target_encoder = self._local_model.target_encoder
            elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == TARGET:
                logger.info("Encoder: aggregate target, update target")
                online_encoder = self._local_model.online_encoder
                target_encoder = self.compressed_model.target_encoder
            elif self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == TARGET:
                logger.info("Encoder: aggregate online, update target")
                online_encoder = self._local_model.online_encoder
                target_encoder = self.compressed_model.online_encoder
            elif self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == BOTH:
                logger.info("Encoder: aggregate online, update both")
                online_encoder = self.compressed_model.online_encoder
                target_encoder = self.compressed_model.online_encoder
            elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == BOTH:
                logger.info("Encoder: aggregate target, update both")
                online_encoder = self.compressed_model.target_encoder
                target_encoder = self.compressed_model.target_encoder
            elif self.conf.update_encoder == NONE:
                logger.info("Encoder: use local online and target encoders")
                online_encoder = self._local_model.online_encoder
                target_encoder = self._local_model.target_encoder
            elif self.conf.update_encoder == EMA:
                logger.info(f"Encoder: use EMA, weight {self.conf.encoder_weight}")
                online_encoder = self._local_model.online_encoder
                ema_updater = model.EMA(self.conf.encoder_weight)
                model.update_moving_average(ema_updater, online_encoder, self.compressed_model.online_encoder)
                target_encoder = self._local_model.target_encoder
            elif self.conf.update_encoder == DYNAMIC_EMA_ONLINE:
                # Use FedEMA to update the online encoder
                online_encoder, target_encoder = ema_online()
            elif self.conf.update_encoder == SELECTIVE_EMA:
                # Use FedEMA to update the online encoder.
                # With random selection, only update with EMA when the client was selected in the previous round.
                if self.previous_trained_round + 1 == self.conf.round_id:
                    online_encoder, target_encoder = ema_online()
                else:
                    logger.info(f"Encoder: update online and target @ round {self.conf.round_id}")
                    online_encoder = self.compressed_model.online_encoder
                    target_encoder = self.compressed_model.online_encoder
            else:
                raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
                                 f"update {self.conf.update_encoder} is not supported")

            if self.conf.update_predictor == GLOBAL:
                logger.info("Predictor: use global predictor")
                predictor = self.compressed_model.online_predictor
            elif self.conf.update_predictor == LOCAL:
                logger.info("Predictor: use local predictor")
                predictor = self._local_model.online_predictor
            elif self.conf.update_predictor == DAPU:
                # Divergence-aware predictor update (DAPU)
                logger.info(f"Predictor: use DAPU, mu {self.conf.dapu_threshold}")
                if self.DAPU_predictor == GLOBAL:
                    predictor = self.compressed_model.online_predictor
                elif self.DAPU_predictor == LOCAL:
                    predictor = self._local_model.online_predictor
                else:
                    raise ValueError("Predictor: DAPU predictor can only use the local or global predictor")
            elif self.conf.update_predictor == DYNAMIC_DAPU:
                # Use FedEMA to update the predictor
                predictor = ema_predictor()
            elif self.conf.update_predictor == SELECTIVE_EMA:
                # With random selection, only update with EMA when the client was selected in the previous round.
                if self.previous_trained_round + 1 == self.conf.round_id:
                    predictor = ema_predictor()
                else:
                    logger.info("Predictor: use global predictor")
                    predictor = self.compressed_model.online_predictor
            elif self.conf.update_predictor == EMA:
                logger.info(f"Predictor: use EMA, weight {self.conf.predictor_weight}")
                predictor = self._local_model.online_predictor
                ema_updater = model.EMA(self.conf.predictor_weight)
                model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
            else:
                raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")

            self.model.online_encoder = copy.deepcopy(online_encoder)
            self.model.target_encoder = copy.deepcopy(target_encoder)
            self.model.online_predictor = copy.deepcopy(predictor)

    def train(self, conf, device=CPU):
        start_time = time.time()
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        if conf.model in [model.MoCo, model.MoCoV2]:
            self.model.reset_key_encoder()
        self.train_loss = []
        self.model.to(device)
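        # Keep a CPU snapshot of every submodule except the last (the encoder
        # part of the model) so the round's weight divergence can be measured
        # after local training.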
        old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
        for i in range(conf.local_epoch):
            batch_loss = []
            for batch_index, ((batched_x1, batched_x2), _) in enumerate(self.train_loader):
                # if conf.round_id > 0 and batch_index not in self.batch_id:
                #     continue
                x1, x2 = batched_x1.to(device), batched_x2.to(device)
                optimizer.zero_grad()
                if conf.model in [model.MoCo, model.MoCoV2]:
                    loss = self.model(x1, x2, device)
                elif conf.model == model.SimCLR:
                    images = torch.cat((x1, x2), dim=0)
                    features = self.model(images)
                    logits, labels = self.info_nce_loss(features)
                    loss = loss_fn(logits, labels)
                else:
                    loss = self.model(x1, x2)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
                if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
                    self.model.update_moving_average()
            current_epoch_loss = sum(batch_loss) / len(batch_loss)
            self.train_loss.append(float(current_epoch_loss))
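
            # Rank training batches once, in round 0 after the fifth local epoch,
            # by the per-batch score from utils.information (assumed to return one
            # value per batch), and keep the 90% of batch indices with the lowest
            # scores for the privacy-aware batch filtering commented out above.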
            if conf.round_id == 0 and i == 4:
                all_mi = utils.information(self.model.online_encoder, self.train_loader, device, self.cid)
                # Pair each score with its batch index and sort by the score
                sorted_ls_with_indices = sorted(enumerate(all_mi), key=lambda x: x[1])
                # Extract the batch indices in sorted order
                sorted_indices = [index for index, _ in sorted_ls_with_indices]
                self.batch_id_privacy = sorted_indices  # batch-level privacy ranking
                self.batch_id = self.batch_id_privacy[:int(len(self.batch_id_privacy) * 0.9)]
        self.train_time = time.time() - start_time

        # Store the trained model locally
        self._local_model = copy.deepcopy(self.model).cpu()
        self.previous_trained_round = conf.round_id

        if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
            new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
            self.encoder_distance = self._calculate_divergence(old_model, new_model)
            self.encoder_distances.append(self.encoder_distance.item())
            self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
            if self.conf.auto_scaler == 'y' and self.conf.random_selection:
                self._calculate_weight_scaler()
            if (conf.round_id + 1) % 100 == 0:
                logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
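
    # DAPU rule: adopt the global predictor while the encoder divergence stays
    # below the configured threshold (conf.dapu_threshold, logged above as mu),
    # and fall back to the local predictor once it drifts past it.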
    def _DAPU_predictor_usage(self, distance):
        if distance < self.conf.dapu_threshold:
            return GLOBAL
        else:
            return LOCAL

    def _calculate_divergence(self, old_model, new_model, typ=L2):
        size = 0
        total_distance = 0
        old_dict = old_model.state_dict()
        new_dict = new_model.state_dict()
        for name, _ in old_model.named_parameters():
            if 'conv' in name and 'weight' in name:
                total_distance += self._calculate_distance(old_dict[name].detach().clone().view(1, -1),
                                                           new_dict[name].detach().clone().view(1, -1),
                                                           typ)
                size += 1
        distance = total_distance / size
        logger.info(f"Model distance: {distance} = {total_distance}/{size}")
        return distance

    def _calculate_distance(self, m1, m2, typ=L2):
        if typ == L2:
            return torch.dist(m1, m2, 2)
        raise ValueError(f"Distance type {typ} is not supported")
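
    # The weight scaler normalizes encoder divergence into an EMA weight. With
    # auto_scaler enabled it is calibrated once from the first measured distance
    # (auto_scaler_target / distance); otherwise the fixed conf.weight_scaler is used.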
    def _calculate_weight_scaler(self):
        if not self.weight_scaler:
            if self.conf.auto_scaler == 'y':
                self.weight_scaler = self.conf.auto_scaler_target / self.encoder_distance
            else:
                self.weight_scaler = self.conf.weight_scaler
            logger.info(f"Client {self.cid}: weight scaler {self.weight_scaler}")

    def load_loader(self, conf):
        drop_last = conf.drop_last
        train_loader = self.train_data.loader(conf.batch_size,
                                              self.cid,
                                              shuffle=True,
                                              drop_last=drop_last,
                                              seed=conf.seed,
                                              transform=self._load_transform(conf))
        _print_label_count(self.cid, self.train_data.data[self.cid]['y'])
        return train_loader

    def load_optimizer(self, conf):
        lr = conf.optimizer.lr
        if conf.optimizer.lr_type == "cosine":
            lr = compute_lr(conf.round_id, conf.rounds, 0, conf.optimizer.lr)

        # MoCo v1 should use the default learning rate
        if conf.model == model.MoCo:
            lr = conf.optimizer.lr

        params = self.model.parameters()
        if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
            params = [
                {'params': self.model.online_encoder.parameters()},
                {'params': self.model.online_predictor.parameters()}
            ]
        if conf.optimizer.type == "Adam":
            optimizer = torch.optim.Adam(params, lr=lr)
        else:
            optimizer = torch.optim.SGD(params,
                                        lr=lr,
                                        momentum=conf.optimizer.momentum,
                                        weight_decay=conf.optimizer.weight_decay)
        return optimizer

    def _load_transform(self, conf):
        transformation = utils.get_transformation(conf.model)
        return transformation(conf.image_size, conf.gaussian)

    def post_upload(self):
        if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
            del self.model
            del self.compressed_model
            self.model = None
            self.compressed_model = None
            assert self.model is None
            assert self.compressed_model is None
            gc.collect()
            torch.cuda.empty_cache()
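
    # InfoNCE loss for SimCLR: with n_views views per image, each row's positives
    # are the other views of the same image and every other sample in the batch
    # is a negative. Logits are cosine similarities (features are L2-normalized)
    # scaled by the temperature, with the positive column placed first so the
    # cross-entropy target label is always 0.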
    def info_nce_loss(self, features, n_views=2, temperature=0.07):
        labels = torch.cat([torch.arange(self.conf.batch_size) for i in range(n_views)], dim=0)
        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
        labels = labels.to(self.device)

        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)
        # assert similarity_matrix.shape == (
        #     n_views * self.conf.batch_size, n_views * self.conf.batch_size)
        # assert similarity_matrix.shape == labels.shape

        # Discard the main diagonal from both the labels and the similarity matrix
        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
        # assert similarity_matrix.shape == labels.shape

        # Select and combine multiple positives
        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
        # Select only the negatives
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
        logits = logits / temperature
        return logits, labels


def compute_lr(current_round, rounds=800, eta_min=0, eta_max=0.3):
    """Compute the learning rate with cosine decay:
    eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T)).
    """
    pi = np.pi
    eta_t = eta_min + 0.5 * (eta_max - eta_min) * (np.cos(pi * current_round / rounds) + 1)
    return eta_t


def _print_label_count(cid, labels):
    logger.info(f"client {cid}: {Counter(labels)}")