瀏覽代碼

fix: test

shellmiao 1 年之前
父節點
當前提交
8805a0951f
共有 1 個文件被更改,包括 10 次插入1 次删除
  1. 10 1
      applications/fedssl/client_with_pgfed.py

+ 10 - 1
applications/fedssl/client_with_pgfed.py

@@ -91,13 +91,22 @@ class FedSSLWithPgFedClient(FedSSLClient):
 
             current_epoch_loss = sum(batch_loss) / len(batch_loss)
             self.train_loss.append(float(current_epoch_loss))
+        
+        # delete later
+        all_grads_none = True
+        for p in zip(self.model.parameters()):
+            if p.grad is not None:
+                all_grads_none = False
+        if all_grads_none:
+            print("123All None------------------")
 
         self.loss_minus = 0.0
         test_num = 0
         optimizer.zero_grad()
         data_count = 0 # delete later
         for (batched_x1, batched_x2), _ in self.train_loader:
-            if data_count >= 50:
+            print(data_count)
+            if data_count >= 500:
                 break
             x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
             data_count += x1.size(0)