|
@@ -61,8 +61,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
data_count = 0 # delete later
|
|
|
batch_loss = []
|
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
|
- print(data_count)
|
|
|
- if data_count >= 50:
|
|
|
+ if data_count >= 500:
|
|
|
break
|
|
|
x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
|
data_count += x1.size(0)
|
|
@@ -92,21 +91,12 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
|
|
|
current_epoch_loss = sum(batch_loss) / len(batch_loss)
|
|
|
self.train_loss.append(float(current_epoch_loss))
|
|
|
-
|
|
|
- # delete later
|
|
|
- all_grads_none = True
|
|
|
- for p in self.model.parameters():
|
|
|
- if p.grad is not None:
|
|
|
- all_grads_none = False
|
|
|
- if all_grads_none:
|
|
|
- print("123All None------------------")
|
|
|
|
|
|
self.loss_minus = 0.0
|
|
|
test_num = 0
|
|
|
optimizer.zero_grad()
|
|
|
data_count = 0 # delete later
|
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
|
- print(data_count)
|
|
|
if data_count >= 500:
|
|
|
break
|
|
|
x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
|
|
@@ -130,16 +120,15 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
self.latest_grad = copy.deepcopy(self.model)
|
|
|
|
|
|
# delete later
|
|
|
- all_grads_none = True
|
|
|
- for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
|
|
|
- if p.grad is not None:
|
|
|
- p_l.data = p.grad.data.clone() / len(self.train_loader)
|
|
|
- all_grads_none = False
|
|
|
- else:
|
|
|
- p_l.data = torch.zeros_like(p_l.data)
|
|
|
-
|
|
|
- if all_grads_none:
|
|
|
- print("All None------------------")
|
|
|
+ # all_grads_none = True
|
|
|
+ # for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
|
|
|
+ # if p.grad is not None:
|
|
|
+ # p_l.data = p.grad.data.clone() / len(self.train_loader)
|
|
|
+ # all_grads_none = False
|
|
|
+ # else:
|
|
|
+ # p_l.data = torch.zeros_like(p_l.data)
|
|
|
+ # if all_grads_none:
|
|
|
+ # print("All None")
|
|
|
|
|
|
self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
|
|
|
|