Pārlūkot izejas kodu

fix :fix TypeError

shellmiao 1 gadu atpakaļ
vecāks
revīzija
33ac3a7e0e
1 mainītis faili ar 3 papildinājumiem un 1 dzēšanām
  1. 3 1
      applications/fedssl/server_with_pgfed.py

+ 3 - 1
applications/fedssl/server_with_pgfed.py

@@ -32,7 +32,7 @@ class FedSSLWithPgFedServer(FedSSLServer):
 
         self.mu = 0
         self.momentum = 0.0
-        self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
+        self.alpha_mat = None
         self.uploaded_grads = {}
         self.loss_minuses = {}
         self.mean_grad = None
@@ -57,6 +57,8 @@ class FedSSLWithPgFedServer(FedSSLServer):
         self.track(metric.TRAIN_TIME, train_time)
     
     def send_param(self):
+        if not self.alpha_mat:
+            self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
         for client in self.grouped_clients:
             client.a_i = self.alpha_mat[client.id]
         if len(self.uploaded_grads) == 0: