Browse Source

fix: AttributeError

shellmiao 1 year ago
parent
commit
3df37e6a12
1 changed files with 11 additions and 2 deletions
  1. 11 2
      applications/fedssl/client_with_pgfed.py

+ 11 - 2
applications/fedssl/client_with_pgfed.py

@@ -42,7 +42,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.previous_trained_round = -1
         self.weight_scaler = None
 
-        self.latest_grad = copy.deepcopy(self.model)
+        self.latest_grad = None
         self.lambdaa = 1.0 # PGFed learning rate for a_i, Regularization weight for pFedMe
         self.prev_loss_minuses = {}
         self.prev_mean_grad = None
@@ -58,9 +58,13 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.model.to(device)
         old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
         for i in range(conf.local_epoch):
+            data_count = 0 # delete later
             batch_loss = []
             for (batched_x1, batched_x2), _ in self.train_loader:
+                if data_count >= 50:
+                    break
                 x1, x2 = batched_x1.to(device), batched_x2.to(device)
+                data_count += x1.size(0)
                 optimizer.zero_grad()
 
                 if conf.model in [model.MoCo, model.MoCoV2]:
@@ -91,8 +95,12 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.loss_minus = 0.0
         test_num = 0
         optimizer.zero_grad()
+        data_count = 0 # delete later
         for (batched_x1, batched_x2), _ in self.train_loader:
+            if data_count >= 50:
+                break
             x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
+            data_count += x1.size(0)
             test_num += x1.size(0)
 
             if conf.model in [model.MoCo, model.MoCoV2]:
@@ -108,7 +116,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
             self.loss_minus += loss.item() * x1.size(0)
 
         self.loss_minus /= test_num
-
+        if not self.latest_grad:
+            self.latest_grad = copy.deepcopy(self.model)
         for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
             p_l.data = p.grad.data.clone() / len(self.train_loader)
         self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)