浏览代码

fix: fix untimeError: Boolean value of Tensor with more than one value is ambiguous

shellmiao 11 月之前
父节点
当前提交
68db222400
共有 2 个文件被更改,包括 3 次插入3 次删除
  1. 2 2
      applications/fedssl/client_with_pgfed.py
  2. 1 1
      applications/fedssl/server_with_pgfed.py

+ 2 - 2
applications/fedssl/client_with_pgfed.py

@@ -61,7 +61,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
             data_count = 0 # delete later
             batch_loss = []
             for (batched_x1, batched_x2), _ in self.train_loader:
-                if data_count >= 500:
+                if data_count >= 50:
                     break
                 x1, x2 = batched_x1.to(device), batched_x2.to(device)
                 data_count += x1.size(0)
@@ -97,7 +97,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
         optimizer.zero_grad()
         data_count = 0 # delete later
         for (batched_x1, batched_x2), _ in self.train_loader:
-            if data_count >= 500:
+            if data_count >= 50:
                 break
             x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
             data_count += x1.size(0)

+ 1 - 1
applications/fedssl/server_with_pgfed.py

@@ -64,7 +64,7 @@ class FedSSLWithPgFedServer(FedSSLServer):
         self.track(metric.TRAIN_TIME, train_time)
     
     def send_param(self):
-        if not self.alpha_mat:
+        if self.alpha_mat!=None:
             self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
         for client in self.grouped_clients:
             client.a_i = self.alpha_mat[client.id]