|
@@ -48,7 +48,6 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
self.prev_mean_grad = None
|
|
|
self.prev_convex_comb_grad = None
|
|
|
self.a_i = None
|
|
|
- self.criterion = nn.CrossEntropyLoss()
|
|
|
|
|
|
def train(self, conf, device=CPU):
|
|
|
start_time = time.time()
|
|
@@ -89,23 +88,27 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
current_epoch_loss = sum(batch_loss) / len(batch_loss)
|
|
|
self.train_loss.append(float(current_epoch_loss))
|
|
|
|
|
|
- # get loss_minus and latest_grad
|
|
|
self.loss_minus = 0.0
|
|
|
test_num = 0
|
|
|
- optimizer.zero_grad()
|
|
|
- for i, (x, y) in enumerate(self.train_loader):
|
|
|
- if type(x) == type([]):
|
|
|
- x[0] = x[0].to(self.device)
|
|
|
+ self.optimizer.zero_grad()
|
|
|
+ for (batched_x1, batched_x2), _ in self.train_loader:
|
|
|
+ x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
|
|
|
+ test_num += x1.size(0)
|
|
|
+
|
|
|
+ if conf.model in [model.MoCo, model.MoCoV2]:
|
|
|
+ loss = self.model(x1, x2, device)
|
|
|
+ elif conf.model == model.SimCLR:
|
|
|
+ images = torch.cat((x1, x2), dim=0)
|
|
|
+ features = self.model(images)
|
|
|
+ logits, labels = self.info_nce_loss(features)
|
|
|
+ loss = loss_fn(logits, labels)
|
|
|
else:
|
|
|
- x = x.to(self.device)
|
|
|
- y = y.to(self.device)
|
|
|
- test_num += y.shape[0]
|
|
|
- output = self.model(x)
|
|
|
- loss = self.criterion(output, y)
|
|
|
- self.loss_minus += (loss * y.shape[0]).item()
|
|
|
- loss.backward()
|
|
|
+ loss = self.model(x1, x2)
|
|
|
+
|
|
|
+ self.loss_minus += loss.item() * x1.size(0)
|
|
|
|
|
|
self.loss_minus /= test_num
|
|
|
+
|
|
|
for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
|
|
|
p_l.data = p.grad.data.clone() / len(self.train_loader)
|
|
|
self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
|