@@ -90,7 +90,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
self.loss_minus = 0.0
test_num = 0
- self.optimizer.zero_grad()
+ optimizer.zero_grad()
for (batched_x1, batched_x2), _ in self.train_loader:
x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
test_num += x1.size(0)