@@ -149,6 +149,8 @@ def run():
},
}
config.update(distribute_config)
+ else:
+ config["gpu"] = args.gpu
if args.semi_supervised:
train_data, test_data, _ = get_semi_supervised_dataset(args.dataset,
@@ -98,7 +98,7 @@ class FedSSLServer(BaseServer):
self._get_test_data()
with torch.no_grad():
- accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader)
+ accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader, device=device)
test_results = {
metric.TEST_ACCURACY: float(accuracy),
@@ -91,7 +91,7 @@ class Coordinator(object):
if self.conf.gpu == 0:
self.conf.device = "cpu"
elif self.conf.gpu == 1:
- self.conf.device = 0
+ self.conf.device = "cuda"
else:
self.conf.device = get_device(self.conf.gpu, self.conf.distributed.world_size,
self.conf.distributed.local_rank)