shellmiao 11 ماه پیش
والد
کامیت
0583c1d3ce
1فایلهای تغییر یافته به همراه11 افزوده شده و 11 حذف شده
  1. 11 11
      applications/fedssl/client_with_pgfed.py

+ 11 - 11
applications/fedssl/client_with_pgfed.py

@@ -135,17 +135,17 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.train_time = time.time() - start_time
 
         # store trained model locally
-        self._local_model = copy.deepcopy(self.model).cpu()
-        self.previous_trained_round = conf.round_id
-        if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
-            new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
-            self.encoder_distance = self._calculate_divergence(old_model, new_model)
-            self.encoder_distances.append(self.encoder_distance.item())
-            self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
-            if self.conf.auto_scaler == 'y' and self.conf.random_selection:
-                self._calculate_weight_scaler()
-            if (conf.round_id + 1) % 100 == 0:
-                logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
+        # self._local_model = copy.deepcopy(self.model).cpu()
+        # self.previous_trained_round = conf.round_id
+        # if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
+        #     new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
+        #     self.encoder_distance = self._calculate_divergence(old_model, new_model)
+        #     self.encoder_distances.append(self.encoder_distance.item())
+        #     self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
+        #     if self.conf.auto_scaler == 'y' and self.conf.random_selection:
+        #         self._calculate_weight_scaler()
+        #     if (conf.round_id + 1) % 100 == 0:
+        #         logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
 
     def update_a_i(self, dot_prod):
         for clt_j, mu_loss_minus in self.prev_loss_minuses.items():