Przeglądaj źródła

feat: finish test

shellmiao 11 miesięcy temu
rodzic
commit
d35c3cb0a4
1 zmienionych plików z 5 dodań i 14 usunięć
  1. 5 14
      applications/fedssl/client_with_pgfed.py

+ 5 - 14
applications/fedssl/client_with_pgfed.py

@@ -58,13 +58,9 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.model.to(device)
         old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
         for i in range(conf.local_epoch):
-            data_count = 0 # delete later
             batch_loss = []
             for (batched_x1, batched_x2), _ in self.train_loader:
-                if data_count >= 50:
-                    break
                 x1, x2 = batched_x1.to(device), batched_x2.to(device)
-                data_count += x1.size(0)
                 optimizer.zero_grad()
 
                 if conf.model in [model.MoCo, model.MoCoV2]:
@@ -95,12 +91,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.loss_minus = 0.0
         test_num = 0
         optimizer.zero_grad()
-        data_count = 0 # delete later
         for (batched_x1, batched_x2), _ in self.train_loader:
-            if data_count >= 50:
-                break
             x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
-            data_count += x1.size(0)
             test_num += x1.size(0)
 
             if conf.model in [model.MoCo, model.MoCoV2]:
@@ -119,7 +111,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
         if not self.latest_grad:
             self.latest_grad = copy.deepcopy(self.model)
         
-        # delete later
+        # delete later: for test 
         # all_grads_none = True
         # for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
         #     if p.grad is not None:
@@ -154,19 +146,18 @@ class FedSSLWithPgFedClient(FedSSLClient):
 
     def set_prev_mean_grad(self, mean_grad):
         if self.prev_mean_grad is None:
-            print("initing prev_mean_grad")
-            print(mean_grad)
+            print("Initing prev_mean_grad")
             self.prev_mean_grad = copy.deepcopy(mean_grad)
         else:
-            print("setting prev_mean_grad")
+            print("Setting prev_mean_grad")
             self.set_model(self.prev_mean_grad, mean_grad)
 
     def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
         if self.prev_convex_comb_grad is None:
-            print("initing prev_convex_comb_grad")
+            print("Initing prev_convex_comb_grad")
             self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
         else:
-            print("setting prev_convex_comb_grad")
+            print("Setting prev_convex_comb_grad")
             self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
 
     def set_model(self, old_m, new_m, momentum=0.0):