Kaynağa Gözat

fix: fix ZeroDivisionError

shellmiao 11 ay önce
ebeveyn
işleme
8ef104bda6

+ 6 - 3
applications/fedssl/client.py

@@ -257,9 +257,12 @@ class FedSSLClient(BaseClient):
                                                            new_dict[name].detach().clone().view(1, -1),
                                                            typ)
                 size += 1
-        distance = total_distance / size
-        logger.info(f"Model distance: {distance} = {total_distance}/{size}")
-        return distance
+        if size!=0:
+            distance = total_distance / size
+            logger.info(f"Model distance: {distance} = {total_distance}/{size}")
+            return distance
+        else:
+            return 0
 
     def _calculate_distance(self, m1, m2, typ=L2):
         if typ == L2:

+ 10 - 21
applications/fedssl/client_with_pgfed.py

@@ -61,8 +61,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
             data_count = 0 # delete later
             batch_loss = []
             for (batched_x1, batched_x2), _ in self.train_loader:
-                print(data_count)
-                if data_count >= 50:
+                if data_count >= 500:
                     break
                 x1, x2 = batched_x1.to(device), batched_x2.to(device)
                 data_count += x1.size(0)
@@ -92,21 +91,12 @@ class FedSSLWithPgFedClient(FedSSLClient):
 
             current_epoch_loss = sum(batch_loss) / len(batch_loss)
             self.train_loss.append(float(current_epoch_loss))
-        
-        # delete later
-        all_grads_none = True
-        for p in self.model.parameters():
-            if p.grad is not None:
-                all_grads_none = False
-        if all_grads_none:
-            print("123All None------------------")
 
         self.loss_minus = 0.0
         test_num = 0
         optimizer.zero_grad()
         data_count = 0 # delete later
         for (batched_x1, batched_x2), _ in self.train_loader:
-            print(data_count)
             if data_count >= 500:
                 break
             x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
@@ -130,16 +120,15 @@ class FedSSLWithPgFedClient(FedSSLClient):
             self.latest_grad = copy.deepcopy(self.model)
         
         # delete later
-        all_grads_none = True
-        for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
-            if p.grad is not None:
-                p_l.data = p.grad.data.clone() / len(self.train_loader)
-                all_grads_none = False
-            else:
-                p_l.data = torch.zeros_like(p_l.data)
-
-        if all_grads_none:
-            print("All None------------------")
+        # all_grads_none = True
+        # for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
+        #     if p.grad is not None:
+        #         p_l.data = p.grad.data.clone() / len(self.train_loader)
+        #         all_grads_none = False
+        #     else:
+        #         p_l.data = torch.zeros_like(p_l.data)
+        # if all_grads_none:
+        #     print("All None")
         
         self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)