|
@@ -98,7 +98,7 @@ class FedSSLServer(BaseServer):
|
|
|
self._get_test_data()
|
|
|
|
|
|
with torch.no_grad():
|
|
|
- accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader)
|
|
|
+ accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader, device=device)
|
|
|
|
|
|
test_results = {
|
|
|
metric.TEST_ACCURACY: float(accuracy),
|