@@ -0,0 +1,150 @@
+import copy
+import gc
+import logging
+import time
+from collections import Counter
+
+import numpy as np
+import torch
+import torch._utils
+import torch.nn as nn
+import torch.nn.functional as F
+
+import model
+import utils
+from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
+from easyfl.client.base import BaseClient
+from easyfl.distributed.distributed import CPU
+
+from client import FedSSLClient
+
+logger = logging.getLogger(__name__)
+
+L2 = "l2"
+
+def model_dot_product(w1, w2, requires_grad=True):
+    """ Return the dot product between the parameters of two models. """
+    dot_product = 0.0
+    for p1, p2 in zip(w1.parameters(), w2.parameters()):
+        if requires_grad:
+            dot_product += torch.sum(p1 * p2)
+        else:
+            dot_product += torch.sum(p1.data * p2.data)
+    return dot_product
+
+class FedSSLWithPgFedClient(FedSSLClient):
+    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
+        super(FedSSLWithPgFedClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
+        self._local_model = None
+        self.DAPU_predictor = LOCAL
+        self.encoder_distance = 1
+        self.encoder_distances = []
+        self.previous_trained_round = -1
+        self.weight_scaler = None
+
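+        # PGFed state: latest local gradient, cached aggregates from the server, and the personalized aggregation weights a_i.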
+        self.latest_grad = copy.deepcopy(self.model)
+        self.lambdaa = 1.0  # PGFed learning rate for a_i; also the regularization weight for pFedMe
+        self.prev_loss_minuses = {}
+        self.prev_mean_grad = None
+        self.prev_convex_comb_grad = None
+        self.a_i = None
+
+    def train(self, conf, device=CPU):
+        start_time = time.time()
+        loss_fn, optimizer = self.pretrain_setup(conf, device)
+        if conf.model in [model.MoCo, model.MoCoV2]:
+            self.model.reset_key_encoder()
+        self.train_loss = []
+        self.model.to(device)
+        old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
+        for i in range(conf.local_epoch):
+            batch_loss = []
+            for (batched_x1, batched_x2), _ in self.train_loader:
+                x1, x2 = batched_x1.to(device), batched_x2.to(device)
+                optimizer.zero_grad()
+
+                if conf.model in [model.MoCo, model.MoCoV2]:
+                    loss = self.model(x1, x2, device)
+                elif conf.model == model.SimCLR:
+                    images = torch.cat((x1, x2), dim=0)
+                    features = self.model(images)
+                    logits, labels = self.info_nce_loss(features)
+                    loss = loss_fn(logits, labels)
+                else:
+                    loss = self.model(x1, x2)
+
+                loss.backward()
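+                # PGFed correction: add the aggregated auxiliary gradient to the local gradient,
+                # then take a gradient step on a_i using the dot product with the previous mean gradient.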
+                if self.prev_convex_comb_grad is not None:
+                    for p_m, p_prev_conv in zip(self.model.parameters(), self.prev_convex_comb_grad.parameters()):
+                        p_m.grad.data += p_prev_conv.data
+                    dot_prod = model_dot_product(self.model, self.prev_mean_grad, requires_grad=False)
+                    self.update_a_i(dot_prod)
+                optimizer.step()
+                batch_loss.append(loss.item())
+
+            if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
+                self.model.update_moving_average()
+
+            current_epoch_loss = sum(batch_loss) / len(batch_loss)
+            self.train_loss.append(float(current_epoch_loss))
+
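+        # PGFed bookkeeping: estimate the average local gradient (latest_grad) over one pass of
+        # the data and the linearized loss term loss_minus = avg_loss - <latest_grad, model>.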
+        # get loss_minus and latest_grad
+        self.loss_minus = 0.0
+        test_num = 0
+        optimizer.zero_grad()
+        for i, (x, y) in enumerate(self.train_loader):
+            if isinstance(x, list):
+                x[0] = x[0].to(device)
+            else:
+                x = x.to(device)
+            y = y.to(device)
+            test_num += y.shape[0]
+            output = self.model(x)
+            loss = self.criterion(output, y)
+            self.loss_minus += (loss * y.shape[0]).item()
+            loss.backward()
+
+        self.loss_minus /= test_num
+        for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
+            p_l.data = p.grad.data.clone() / len(self.train_loader)
+        self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
+
+        self.train_time = time.time() - start_time
+
+        # store trained model locally
+        self._local_model = copy.deepcopy(self.model).cpu()
+        self.previous_trained_round = conf.round_id
+        if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
+            new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
+            self.encoder_distance = self._calculate_divergence(old_model, new_model)
+            self.encoder_distances.append(self.encoder_distance.item())
+            self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
+            if self.conf.auto_scaler == 'y' and self.conf.random_selection:
+                self._calculate_weight_scaler()
+            if (conf.round_id + 1) % 100 == 0:
+                logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
+
+    def update_a_i(self, dot_prod):
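+        # Gradient step on the personalized aggregation weights a_i, then clip them to be non-negative.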
+        for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
+            self.a_i[clt_j] -= self.lambdaa * (mu_loss_minus + dot_prod)
+            self.a_i[clt_j] = max(self.a_i[clt_j], 0.0)
+
+    def set_prev_mean_grad(self, mean_grad):
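+        # Cache the latest mean gradient; subsequent updates go through set_model instead of a fresh deepcopy.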
+        if self.prev_mean_grad is None:
+            self.prev_mean_grad = copy.deepcopy(mean_grad)
+        else:
+            self.set_model(self.prev_mean_grad, mean_grad)
+
+    def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
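+        # Cache the convex combination of client gradients; the momentum argument is passed through to set_model.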
+        if self.prev_convex_comb_grad is None:
+            self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
+        else:
+            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)