|
@@ -118,8 +118,19 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
self.loss_minus /= test_num
|
|
|
if not self.latest_grad:
|
|
|
self.latest_grad = copy.deepcopy(self.model)
|
|
|
+
|
|
|
+ # delete later
|
|
|
+ all_grads_none = True
|
|
|
for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
|
|
|
- p_l.data = p.grad.data.clone() / len(self.train_loader)
|
|
|
+ if p.grad is not None:
|
|
|
+ p_l.data = p.grad.data.clone() / len(self.train_loader)
|
|
|
+ all_grads_none = False
|
|
|
+ else:
|
|
|
+ p_l.data = torch.zeros_like(p_l.data)
|
|
|
+
|
|
|
+ if all_grads_none:
|
|
|
+ print("All None------------------")
|
|
|
+
|
|
|
self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
|
|
|
|
|
|
self.train_time = time.time() - start_time
|