|
@@ -42,7 +42,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
self.previous_trained_round = -1
|
|
|
self.weight_scaler = None
|
|
|
|
|
|
- self.latest_grad = copy.deepcopy(self.model)
|
|
|
+ self.latest_grad = None
|
|
|
self.lambdaa = 1.0 # PGFed learning rate for a_i, Regularization weight for pFedMe
|
|
|
self.prev_loss_minuses = {}
|
|
|
self.prev_mean_grad = None
|
|
@@ -58,9 +58,13 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
self.model.to(device)
|
|
|
old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
|
for i in range(conf.local_epoch):
|
|
|
+ data_count = 0 # delete later
|
|
|
batch_loss = []
|
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
|
+ if data_count >= 50:
|
|
|
+ break
|
|
|
x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
|
+ data_count += x1.size(0)
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
if conf.model in [model.MoCo, model.MoCoV2]:
|
|
@@ -91,8 +95,12 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
self.loss_minus = 0.0
|
|
|
test_num = 0
|
|
|
optimizer.zero_grad()
|
|
|
+ data_count = 0 # delete later
|
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
|
+ if data_count >= 50:
|
|
|
+ break
|
|
|
x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
|
|
|
+ data_count += x1.size(0)
|
|
|
test_num += x1.size(0)
|
|
|
|
|
|
if conf.model in [model.MoCo, model.MoCoV2]:
|
|
@@ -108,7 +116,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
self.loss_minus += loss.item() * x1.size(0)
|
|
|
|
|
|
self.loss_minus /= test_num
|
|
|
-
|
|
|
+ if not self.latest_grad:
|
|
|
+ self.latest_grad = copy.deepcopy(self.model)
|
|
|
for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
|
|
|
p_l.data = p.grad.data.clone() / len(self.train_loader)
|
|
|
self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
|