|
@@ -135,17 +135,17 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
self.train_time = time.time() - start_time
|
|
self.train_time = time.time() - start_time
|
|
|
|
|
|
# store trained model locally
|
|
# store trained model locally
|
|
- self._local_model = copy.deepcopy(self.model).cpu()
|
|
|
|
- self.previous_trained_round = conf.round_id
|
|
|
|
- if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
|
|
|
|
- new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
|
|
- self.encoder_distance = self._calculate_divergence(old_model, new_model)
|
|
|
|
- self.encoder_distances.append(self.encoder_distance.item())
|
|
|
|
- self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
|
|
|
|
- if self.conf.auto_scaler == 'y' and self.conf.random_selection:
|
|
|
|
- self._calculate_weight_scaler()
|
|
|
|
- if (conf.round_id + 1) % 100 == 0:
|
|
|
|
- logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
|
|
|
|
|
|
+ # self._local_model = copy.deepcopy(self.model).cpu()
|
|
|
|
+ # self.previous_trained_round = conf.round_id
|
|
|
|
+ # if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
|
|
|
|
+ # new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
|
|
+ # self.encoder_distance = self._calculate_divergence(old_model, new_model)
|
|
|
|
+ # self.encoder_distances.append(self.encoder_distance.item())
|
|
|
|
+ # self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
|
|
|
|
+ # if self.conf.auto_scaler == 'y' and self.conf.random_selection:
|
|
|
|
+ # self._calculate_weight_scaler()
|
|
|
|
+ # if (conf.round_id + 1) % 100 == 0:
|
|
|
|
+ # logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
|
|
|
|
|
|
def update_a_i(self, dot_prod):
|
|
def update_a_i(self, dot_prod):
|
|
for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
|
|
for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
|