
feat: update loss_minus

shellmiao, 1 year ago
parent
commit
e40523a7f3
1 changed file with 16 additions and 13 deletions

applications/fedssl/client_with_pgfed.py (+16, -13)

@@ -48,7 +48,6 @@ class FedSSLWithPgFedClient(FedSSLClient):
         self.prev_mean_grad = None
         self.prev_convex_comb_grad = None
         self.a_i = None
-        self.criterion = nn.CrossEntropyLoss()
 
     def train(self, conf, device=CPU):
         start_time = time.time()
@@ -89,23 +88,27 @@ class FedSSLWithPgFedClient(FedSSLClient):
             current_epoch_loss = sum(batch_loss) / len(batch_loss)
             self.train_loss.append(float(current_epoch_loss))
 
-        # get loss_minus and latest_grad
         self.loss_minus = 0.0
         test_num = 0
-        optimizer.zero_grad()
-        for i, (x, y) in enumerate(self.train_loader):
-            if type(x) == type([]):
-                x[0] = x[0].to(self.device)
+        self.optimizer.zero_grad()
+        for (batched_x1, batched_x2), _ in self.train_loader:
+            x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
+            test_num += x1.size(0)
+
+            if conf.model in [model.MoCo, model.MoCoV2]:
+                loss = self.model(x1, x2, device)
+            elif conf.model == model.SimCLR:
+                images = torch.cat((x1, x2), dim=0)
+                features = self.model(images)
+                logits, labels = self.info_nce_loss(features)
+                loss = loss_fn(logits, labels)
             else:
-                x = x.to(self.device)
-            y = y.to(self.device)
-            test_num += y.shape[0]
-            output = self.model(x)
-            loss = self.criterion(output, y)
-            self.loss_minus += (loss * y.shape[0]).item()
-            loss.backward()
+                loss = self.model(x1, x2)
+
+            self.loss_minus += loss.item() * x1.size(0)
 
         self.loss_minus /= test_num
+
         for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
             p_l.data = p.grad.data.clone() / len(self.train_loader)
         self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
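
For reference, the SimCLR branch above relies on self.info_nce_loss(features) to turn the concatenated features of both augmented views into a (logits, labels) pair for a cross-entropy criterion. Below is a minimal sketch of such a helper, assuming the standard SimCLR (NT-Xent) formulation with a fixed temperature; the actual info_nce_loss in this repository may differ in details such as the temperature value or projection handling.

import torch
import torch.nn.functional as F

def info_nce_loss(features: torch.Tensor, temperature: float = 0.5):
    """Sketch of an InfoNCE helper for features stacking two views [x1; x2] along dim 0."""
    n = features.shape[0] // 2  # number of original images in the batch

    # positions belonging to the same original image share a label
    labels = torch.cat([torch.arange(n), torch.arange(n)]).to(features.device)
    mask = labels.unsqueeze(0) == labels.unsqueeze(1)  # (2n, 2n) bool

    # cosine similarity between all pairs of normalized features
    features = F.normalize(features, dim=1)
    sim = features @ features.T  # (2n, 2n)

    # drop self-similarity on the diagonal
    eye = torch.eye(2 * n, dtype=torch.bool, device=features.device)
    mask = mask[~eye].view(2 * n, -1)
    sim = sim[~eye].view(2 * n, -1)

    # one positive (the other view) and 2n-2 negatives per row
    positives = sim[mask].view(2 * n, 1)
    negatives = sim[~mask].view(2 * n, -1)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(2 * n, dtype=torch.long, device=features.device)
    return logits, labels

With loss_fn set to a cross-entropy criterion, the loss over these logits (positive at index 0 of every row) reproduces the NT-Xent objective used by SimCLR.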
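
Similarly, the final line subtracts model_dot_product(self.latest_grad, self.model, requires_grad=False) from the averaged loss. The sketch below illustrates what such a helper could look like, assuming it sums the element-wise products of corresponding parameter tensors and returns a detached scalar when requires_grad is False; it is an illustration, not this repository's implementation.

import torch
import torch.nn as nn

def model_dot_product(grad_model: nn.Module, model: nn.Module, requires_grad: bool = True):
    """Sum of element-wise products over corresponding parameter tensors of two models."""
    dot = sum(
        (p_g * p).sum()
        for p_g, p in zip(grad_model.parameters(), model.parameters())
    )
    # return a plain float when no gradient is needed, so it can be
    # subtracted directly from the scalar self.loss_minus
    return dot if requires_grad else float(dot.detach())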