Просмотр исходного кода

[Fix] Fix FedSSL single GPU runtime. (#7)

J_BING 1 год назад
Родитель
Сommit
66bae8d2eb
3 измененных файлов с 4 добавлено и 2 удалено
  1. 2 0
      applications/fedssl/main.py
  2. 1 1
      applications/fedssl/server.py
  3. 1 1
      easyfl/coordinator.py

+ 2 - 0
applications/fedssl/main.py

@@ -149,6 +149,8 @@ def run():
             },
         }
         config.update(distribute_config)
+    else:
+        config["gpu"] = args.gpu
 
     if args.semi_supervised:
         train_data, test_data, _ = get_semi_supervised_dataset(args.dataset,

+ 1 - 1
applications/fedssl/server.py

@@ -98,7 +98,7 @@ class FedSSLServer(BaseServer):
         self._get_test_data()
 
         with torch.no_grad():
-            accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader)
+            accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader, device=device)
 
         test_results = {
             metric.TEST_ACCURACY: float(accuracy),

+ 1 - 1
easyfl/coordinator.py

@@ -91,7 +91,7 @@ class Coordinator(object):
         if self.conf.gpu == 0:
             self.conf.device = "cpu"
         elif self.conf.gpu == 1:
-            self.conf.device = 0
+            self.conf.device = "cuda"
         else:
             self.conf.device = get_device(self.conf.gpu, self.conf.distributed.world_size,
                                           self.conf.distributed.local_rank)