Explorar o código

fix: add test log

shellmiao hai 1 ano
pai
achega
0ec2612bcf
Modificáronse 1 ficheiros con 12 adicións e 1 borrados
  1. 12 1
      applications/fedssl/client_with_pgfed.py

+ 12 - 1
applications/fedssl/client_with_pgfed.py

@@ -118,8 +118,19 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.loss_minus /= test_num
         if not self.latest_grad:
             self.latest_grad = copy.deepcopy(self.model)
+        
+        # delete later
+        all_grads_none = True
         for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
-            p_l.data = p.grad.data.clone() / len(self.train_loader)
+            if p.grad is not None:
+                p_l.data = p.grad.data.clone() / len(self.train_loader)
+                all_grads_none = False
+            else:
+                p_l.data = torch.zeros_like(p_l.data)
+
+        if all_grads_none:
+            print("All None------------------")
+        
         self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
 
         self.train_time = time.time() - start_time