client.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. import copy
  2. import gc
  3. import logging
  4. import time
  5. from collections import Counter
  6. import numpy as np
  7. import torch
  8. import torch._utils
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. import model
  12. import utils
  13. from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
  14. from easyfl.client.base import BaseClient
  15. from easyfl.distributed.distributed import CPU
  16. logger = logging.getLogger(__name__)
  17. L2 = "l2"
  18. class FedSSLClient(BaseClient):
  19. def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
  20. super(FedSSLClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
  21. self._local_model = None
  22. self.DAPU_predictor = LOCAL
  23. self.encoder_distance = 1
  24. self.encoder_distances = []
  25. self.previous_trained_round = -1
  26. self.weight_scaler = None
  27. def decompression(self):
  28. if self.model is None:
  29. # Initialization at beginning of the task
  30. self.model = self.compressed_model
  31. self.update_model()
  32. def update_model(self):
  33. if self.conf.model in [model.MoCo, model.MoCoV2]:
  34. self.model.encoder_q = self.compressed_model.encoder_q
  35. # self.model.encoder_k = copy.deepcopy(self._local_model.encoder_k)
  36. elif self.conf.model == model.SimCLR:
  37. self.model.online_encoder = self.compressed_model.online_encoder
  38. elif self.conf.model in [model.SimSiam, model.SimSiamNoSG]:
  39. if self._local_model is None:
  40. self.model.online_encoder = self.compressed_model.online_encoder
  41. self.model.online_predictor = self.compressed_model.online_predictor
  42. return
  43. if self.conf.update_encoder == ONLINE:
  44. online_encoder = self.compressed_model.online_encoder
  45. else:
  46. raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
  47. f"update {self.conf.update_encoder} is not supported")
  48. if self.conf.update_predictor == GLOBAL:
  49. predictor = self.compressed_model.online_predictor
  50. else:
  51. raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")
  52. self.model.online_encoder = copy.deepcopy(online_encoder)
  53. self.model.online_predictor = copy.deepcopy(predictor)
  54. elif self.conf.model in [model.Symmetric, model.SymmetricNoSG]:
  55. self.model.online_encoder = self.compressed_model.online_encoder
  56. elif self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
  57. if self._local_model is None:
  58. logger.info("Use aggregated encoder and predictor")
  59. self.model.online_encoder = self.compressed_model.online_encoder
  60. self.model.target_encoder = self.compressed_model.online_encoder
  61. self.model.online_predictor = self.compressed_model.online_predictor
  62. return
  63. def ema_online():
  64. self._calculate_weight_scaler()
  65. logger.info(f"Encoder: update online with EMA of global encoder @ round {self.conf.round_id}")
  66. weight = self.encoder_distance
  67. weight = min(1, self.weight_scaler * weight)
  68. weight = 1 - weight
  69. self.compressed_model = self.compressed_model.cpu()
  70. online_encoder = self.compressed_model.online_encoder
  71. target_encoder = self._local_model.target_encoder
  72. ema_updater = model.EMA(weight)
  73. model.update_moving_average(ema_updater, online_encoder, self._local_model.online_encoder)
  74. return online_encoder, target_encoder
  75. def ema_predictor():
  76. logger.info(f"Predictor: use dynamic DAPU")
  77. distance = self.encoder_distance
  78. distance = min(1, distance * self.weight_scaler)
  79. if distance > 0.5:
  80. weight = distance
  81. ema_updater = model.EMA(weight)
  82. predictor = self._local_model.online_predictor
  83. model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
  84. else:
  85. weight = 1 - distance
  86. ema_updater = model.EMA(weight)
  87. predictor = self.compressed_model.online_predictor
  88. model.update_moving_average(ema_updater, predictor, self._local_model.online_predictor)
  89. return predictor
  90. if self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == ONLINE:
  91. logger.info("Encoder: aggregate online, update online")
  92. online_encoder = self.compressed_model.online_encoder
  93. target_encoder = self._local_model.target_encoder
  94. elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == ONLINE:
  95. logger.info("Encoder: aggregate target, update online")
  96. online_encoder = self.compressed_model.target_encoder
  97. target_encoder = self._local_model.target_encoder
  98. elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == TARGET:
  99. logger.info("Encoder: aggregate target, update target")
  100. online_encoder = self._local_model.online_encoder
  101. target_encoder = self.compressed_model.target_encoder
  102. elif self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == TARGET:
  103. logger.info("Encoder: aggregate online, update target")
  104. online_encoder = self._local_model.online_encoder
  105. target_encoder = self.compressed_model.online_encoder
  106. elif self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == BOTH:
  107. logger.info("Encoder: aggregate online, update both")
  108. online_encoder = self.compressed_model.online_encoder
  109. target_encoder = self.compressed_model.online_encoder
  110. elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == BOTH:
  111. logger.info("Encoder: aggregate target, update both")
  112. online_encoder = self.compressed_model.target_encoder
  113. target_encoder = self.compressed_model.target_encoder
  114. elif self.conf.update_encoder == NONE:
  115. logger.info("Encoder: use local online and target encoders")
  116. online_encoder = self._local_model.online_encoder
  117. target_encoder = self._local_model.target_encoder
  118. elif self.conf.update_encoder == EMA:
  119. logger.info(f"Encoder: use EMA, weight {self.conf.encoder_weight}")
  120. online_encoder = self._local_model.online_encoder
  121. ema_updater = model.EMA(self.conf.encoder_weight)
  122. model.update_moving_average(ema_updater, online_encoder, self.compressed_model.online_encoder)
  123. target_encoder = self._local_model.target_encoder
  124. elif self.conf.update_encoder == DYNAMIC_EMA_ONLINE:
  125. # Use FedEMA to update online encoder
  126. online_encoder, target_encoder = ema_online()
  127. elif self.conf.update_encoder == SELECTIVE_EMA:
  128. # Use FedEMA to update online encoder
  129. # For random selection, only update with EMA when the client is selected in previous round.
  130. if self.previous_trained_round + 1 == self.conf.round_id:
  131. online_encoder, target_encoder = ema_online()
  132. else:
  133. logger.info(f"Encoder: update online and target @ round {self.conf.round_id}")
  134. online_encoder = self.compressed_model.online_encoder
  135. target_encoder = self.compressed_model.online_encoder
  136. else:
  137. raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
  138. f"update {self.conf.update_encoder} is not supported")
  139. if self.conf.update_predictor == GLOBAL:
  140. logger.info("Predictor: use global predictor")
  141. predictor = self.compressed_model.online_predictor
  142. elif self.conf.update_predictor == LOCAL:
  143. logger.info("Predictor: use local predictor")
  144. predictor = self._local_model.online_predictor
  145. elif self.conf.update_predictor == DAPU:
  146. # Divergence-aware predictor update (DAPU)
  147. logger.info(f"Predictor: use DAPU, mu {self.conf.dapu_threshold}")
  148. if self.DAPU_predictor == GLOBAL:
  149. predictor = self.compressed_model.online_predictor
  150. elif self.DAPU_predictor == LOCAL:
  151. predictor = self._local_model.online_predictor
  152. else:
  153. raise ValueError(f"Predictor: DAPU predictor can either use local or global predictor")
  154. elif self.conf.update_predictor == DYNAMIC_DAPU:
  155. # Use FedEMA to update predictor
  156. predictor = ema_predictor()
  157. elif self.conf.update_predictor == SELECTIVE_EMA:
  158. # For random selection, only update with EMA when the client is selected in previous round.
  159. if self.previous_trained_round + 1 == self.conf.round_id:
  160. predictor = ema_predictor()
  161. else:
  162. logger.info("Predictor: use global predictor")
  163. predictor = self.compressed_model.online_predictor
  164. elif self.conf.update_predictor == EMA:
  165. logger.info(f"Predictor: use EMA, weight {self.conf.predictor_weight}")
  166. predictor = self._local_model.online_predictor
  167. ema_updater = model.EMA(self.conf.predictor_weight)
  168. model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
  169. else:
  170. raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")
  171. self.model.online_encoder = copy.deepcopy(online_encoder)
  172. self.model.target_encoder = copy.deepcopy(target_encoder)
  173. self.model.online_predictor = copy.deepcopy(predictor)
  174. def train(self, conf, device=CPU):
  175. start_time = time.time()
  176. loss_fn, optimizer = self.pretrain_setup(conf, device)
  177. if conf.model in [model.MoCo, model.MoCoV2]:
  178. self.model.reset_key_encoder()
  179. self.train_loss = []
  180. self.model.to(device)
  181. old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
  182. for i in range(conf.local_epoch):
  183. batch_loss = []
  184. for (batched_x1, batched_x2), _ in self.train_loader:
  185. x1, x2 = batched_x1.to(device), batched_x2.to(device)
  186. optimizer.zero_grad()
  187. if conf.model in [model.MoCo, model.MoCoV2]:
  188. loss = self.model(x1, x2, device)
  189. elif conf.model == model.SimCLR:
  190. images = torch.cat((x1, x2), dim=0)
  191. features = self.model(images)
  192. logits, labels = self.info_nce_loss(features)
  193. loss = loss_fn(logits, labels)
  194. else:
  195. loss = self.model(x1, x2)
  196. loss.backward()
  197. optimizer.step()
  198. batch_loss.append(loss.item())
  199. if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
  200. self.model.update_moving_average()
  201. current_epoch_loss = sum(batch_loss) / len(batch_loss)
  202. self.train_loss.append(float(current_epoch_loss))
  203. self.train_time = time.time() - start_time
  204. # store trained model locally
  205. self._local_model = copy.deepcopy(self.model).cpu()
  206. self.previous_trained_round = conf.round_id
  207. if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
  208. new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
  209. self.encoder_distance = self._calculate_divergence(old_model, new_model)
  210. self.encoder_distances.append(self.encoder_distance.item())
  211. self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
  212. if self.conf.auto_scaler == 'y' and self.conf.random_selection:
  213. self._calculate_weight_scaler()
  214. if (conf.round_id + 1) % 100 == 0:
  215. logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
  216. def _DAPU_predictor_usage(self, distance):
  217. if distance < self.conf.dapu_threshold:
  218. return GLOBAL
  219. else:
  220. return LOCAL
  221. def _calculate_divergence(self, old_model, new_model, typ=L2):
  222. size = 0
  223. total_distance = 0
  224. old_dict = old_model.state_dict()
  225. new_dict = new_model.state_dict()
  226. for name, param in old_model.named_parameters():
  227. if 'conv' in name and 'weight' in name:
  228. total_distance += self._calculate_distance(old_dict[name].detach().clone().view(1, -1),
  229. new_dict[name].detach().clone().view(1, -1),
  230. typ)
  231. size += 1
  232. distance = total_distance / size
  233. logger.info(f"Model distance: {distance} = {total_distance}/{size}")
  234. return distance
  235. def _calculate_distance(self, m1, m2, typ=L2):
  236. if typ == L2:
  237. return torch.dist(m1, m2, 2)
  238. def _calculate_weight_scaler(self):
  239. if not self.weight_scaler:
  240. if self.conf.auto_scaler == 'y':
  241. self.weight_scaler = self.conf.auto_scaler_target / self.encoder_distance
  242. else:
  243. self.weight_scaler = self.conf.weight_scaler
  244. logger.info(f"Client {self.cid}: weight scaler {self.weight_scaler}")
  245. def load_loader(self, conf):
  246. drop_last = conf.drop_last
  247. train_loader = self.train_data.loader(conf.batch_size,
  248. self.cid,
  249. shuffle=True,
  250. drop_last=drop_last,
  251. seed=conf.seed,
  252. transform=self._load_transform(conf))
  253. _print_label_count(self.cid, self.train_data.data[self.cid]['y'])
  254. return train_loader
  255. def load_optimizer(self, conf):
  256. lr = conf.optimizer.lr
  257. if conf.optimizer.lr_type == "cosine":
  258. lr = compute_lr(conf.round_id, conf.rounds, 0, conf.optimizer.lr)
  259. # movo_v1 should use the default learning rate
  260. if conf.model == model.MoCo:
  261. lr = conf.optimizer.lr
  262. params = self.model.parameters()
  263. if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
  264. params = [
  265. {'params': self.model.online_encoder.parameters()},
  266. {'params': self.model.online_predictor.parameters()}
  267. ]
  268. if conf.optimizer.type == "Adam":
  269. optimizer = torch.optim.Adam(params, lr=lr)
  270. else:
  271. optimizer = torch.optim.SGD(params,
  272. lr=lr,
  273. momentum=conf.optimizer.momentum,
  274. weight_decay=conf.optimizer.weight_decay)
  275. return optimizer
  276. def _load_transform(self, conf):
  277. transformation = utils.get_transformation(conf.model)
  278. return transformation(conf.image_size, conf.gaussian)
  279. def post_upload(self):
  280. if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
  281. del self.model
  282. del self.compressed_model
  283. self.model = None
  284. self.compressed_model = None
  285. assert self.model is None
  286. assert self.compressed_model is None
  287. gc.collect()
  288. torch.cuda.empty_cache()
  289. def info_nce_loss(self, features, n_views=2, temperature=0.07):
  290. labels = torch.cat([torch.arange(self.conf.batch_size) for i in range(n_views)], dim=0)
  291. labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
  292. labels = labels.to(self.device)
  293. features = F.normalize(features, dim=1)
  294. similarity_matrix = torch.matmul(features, features.T)
  295. # assert similarity_matrix.shape == (
  296. # n_views * self.conf.batch_size, n_views * self.conf.batch_size)
  297. # assert similarity_matrix.shape == labels.shape
  298. # discard the main diagonal from both: labels and similarities matrix
  299. mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
  300. labels = labels[~mask].view(labels.shape[0], -1)
  301. similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
  302. # assert similarity_matrix.shape == labels.shape
  303. # select and combine multiple positives
  304. positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
  305. # select only the negatives the negatives
  306. negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
  307. logits = torch.cat([positives, negatives], dim=1)
  308. labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
  309. logits = logits / temperature
  310. return logits, labels
  311. def compute_lr(current_round, rounds=800, eta_min=0, eta_max=0.3):
  312. """Compute learning rate as cosine decay"""
  313. pi = np.pi
  314. eta_t = eta_min + 0.5 * (eta_max - eta_min) * (np.cos(pi * current_round / rounds) + 1)
  315. return eta_t
  316. def _print_label_count(cid, labels):
  317. logger.info(f"client {cid}: {Counter(labels)}")