Sfoglia il codice sorgente

feat: modify main, add CrossEntropyLoss

shellmiao 1 anno fa
parent
commit
1855064fb2
2 ha cambiato i file con 11 aggiunte e 3 eliminazioni
  1. 2 1
      applications/fedssl/client_with_pgfed.py
  2. 9 2
      applications/fedssl/main.py

+ 2 - 1
applications/fedssl/client_with_pgfed.py

@@ -48,6 +48,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.prev_mean_grad = None
         self.prev_convex_comb_grad = None
         self.a_i = None
+        self.criterion = nn.CrossEntropyLoss()
 
     def train(self, conf, device=CPU):
         start_time = time.time()
@@ -91,7 +92,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
         # get loss_minus and latest_grad
         self.loss_minus = 0.0
         test_num = 0
-        self.optimizer.zero_grad()
+        optimizer.zero_grad()
         for i, (x, y) in enumerate(self.train_loader):
             if type(x) == type([]):
                 x[0] = x[0].to(self.device)

+ 9 - 2
applications/fedssl/main.py

@@ -58,6 +58,8 @@ def run():
     parser.add_argument('--gpu', default=0, type=int)
     parser.add_argument('--run_count', default=0, type=int)
 
+    parser.add_argument('--use_pgfed', default=False, type=bool)
+
     args = parser.parse_args()
     print("arguments: ", args)
 
@@ -164,8 +166,13 @@ def run():
 
     model = get_model(args.model, args.encoder_network, args.predictor_network)
     easyfl.register_model(model)
-    easyfl.register_client(FedSSLWithPgFedClient)
-    easyfl.register_server(FedSSLWithPgFedServer)
+    if args.use_pgfed:
+        easyfl.register_client(FedSSLWithPgFedClient)
+        easyfl.register_server(FedSSLWithPgFedServer)
+    else:
+        easyfl.register_client(FedSSLClient)
+        easyfl.register_server(FedSSLServer)
+    
     easyfl.init(config, init_all=True)
     easyfl.run()