소스 검색

fix: test

shellmiao 11 달 전
부모
커밋
d457e2486e
2개의 변경된 파일6개의 추가작업 그리고 2개의 파일을 삭제
  1. 4 2
      applications/fedssl/client_with_pgfed.py
  2. 2 0
      applications/fedssl/server_with_pgfed.py

+ 4 - 2
applications/fedssl/client_with_pgfed.py

@@ -24,8 +24,6 @@ L2 = "l2"
 
 def model_dot_product(w1, w2, requires_grad=True):
     """ Return the sum of squared difference between two models. """
-    print(w1)
-    print(w2)
     dot_product = 0.0
     for p1, p2 in zip(w1.parameters(), w2.parameters()):
         if requires_grad:
@@ -156,14 +154,18 @@ class FedSSLWithPgFedClient(FedSSLClient):
 
     def set_prev_mean_grad(self, mean_grad):
         if self.prev_mean_grad is None:
+            print("initing prev_mean_grad")
             self.prev_mean_grad = copy.deepcopy(mean_grad)
         else:
+            print("setting prev_mean_grad")
             self.set_model(self.prev_mean_grad, mean_grad)
 
     def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
         if self.prev_convex_comb_grad is None:
+            print("initing prev_convex_comb_grad")
             self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
         else:
+            print("setting prev_convex_comb_grad")
             self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
 
     def set_model(self, old_m, new_m, momentum=0.0):

+ 2 - 0
applications/fedssl/server_with_pgfed.py

@@ -69,6 +69,8 @@ class FedSSLWithPgFedServer(FedSSLServer):
         for client in self.grouped_clients:
             client.a_i = self.alpha_mat[client.id]
         if len(self.uploaded_grads) == 0:
+            print(len(self.uploaded_grads))
+            print("uploaded_grads=0")
             return
         self.convex_comb_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
         for client in self.grouped_clients: