Browse Source

fix: fix AttributeError

shellmiao 11 months ago
parent
commit
ddb4ae4160

+ 1 - 0
applications/fedssl/client_with_pgfed.py

@@ -155,6 +155,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
     def set_prev_mean_grad(self, mean_grad):
         if self.prev_mean_grad is None:
             print("initing prev_mean_grad")
+            print(mean_grad)
             self.prev_mean_grad = copy.deepcopy(mean_grad)
         else:
             print("setting prev_mean_grad")

+ 1 - 0
applications/fedssl/server_with_pgfed.py

@@ -58,6 +58,7 @@ class FedSSLWithPgFedServer(FedSSLServer):
         self.send_param()
         self.distribution_to_train()
         self.aggregation()
+        self.get_mean_grad()
 
         train_time = time.time() - begin_train_time
         self.print_("Server train time: {}".format(train_time))