|
@@ -91,13 +91,22 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
|
|
|
current_epoch_loss = sum(batch_loss) / len(batch_loss)
|
|
|
self.train_loss.append(float(current_epoch_loss))
|
|
|
+
|
|
|
+ # delete later
|
|
|
+ all_grads_none = True
|
|
|
+ for p in zip(self.model.parameters()):
|
|
|
+ if p.grad is not None:
|
|
|
+ all_grads_none = False
|
|
|
+ if all_grads_none:
|
|
|
+ print("123All None------------------")
|
|
|
|
|
|
self.loss_minus = 0.0
|
|
|
test_num = 0
|
|
|
optimizer.zero_grad()
|
|
|
data_count = 0 # delete later
|
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
|
- if data_count >= 50:
|
|
|
+ print(data_count)
|
|
|
+ if data_count >= 500:
|
|
|
break
|
|
|
x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
|
|
|
data_count += x1.size(0)
|