소스 검색

fix: add test log

shellmiao 1 년 전
부모
커밋
0ec2612bcf
1개의 변경된 파일12개의 추가작업 그리고 1개의 파일을 삭제
  1. 12 1
      applications/fedssl/client_with_pgfed.py

+ 12 - 1
applications/fedssl/client_with_pgfed.py

@@ -118,8 +118,19 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.loss_minus /= test_num
         if not self.latest_grad:
             self.latest_grad = copy.deepcopy(self.model)
+        
+        # delete later
+        all_grads_none = True
         for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
-            p_l.data = p.grad.data.clone() / len(self.train_loader)
+            if p.grad is not None:
+                p_l.data = p.grad.data.clone() / len(self.train_loader)
+                all_grads_none = False
+            else:
+                p_l.data = torch.zeros_like(p_l.data)
+
+        if all_grads_none:
+            print("All None------------------")
+        
         self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
 
         self.train_time = time.time() - start_time