|
@@ -0,0 +1,366 @@
|
|
|
+import copy
|
|
|
+import gc
|
|
|
+import logging
|
|
|
+import time
|
|
|
+from collections import Counter
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+import torch._utils
|
|
|
+import torch.nn as nn
|
|
|
+import torch.nn.functional as F
|
|
|
+
|
|
|
+import model
|
|
|
+import utils
|
|
|
+from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
|
|
|
+from easyfl.client.base import BaseClient
|
|
|
+from easyfl.distributed.distributed import CPU
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+L2 = "l2"
|
|
|
+
|
|
|
+
|
|
|
+class FedSSLClient(BaseClient):
|
|
|
+ def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
|
|
|
+ super(FedSSLClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
|
|
|
+ self._local_model = None
|
|
|
+ self.DAPU_predictor = LOCAL
|
|
|
+ self.encoder_distance = 1
|
|
|
+ self.encoder_distances = []
|
|
|
+ self.previous_trained_round = -1
|
|
|
+ self.weight_scaler = None
|
|
|
+
|
|
|
+ def decompression(self):
|
|
|
+ if self.model is None:
|
|
|
+ # Initialization at beginning of the task
|
|
|
+ self.model = self.compressed_model
|
|
|
+
|
|
|
+ self.update_model()
|
|
|
+
|
|
|
+ def update_model(self):
|
|
|
+ if self.conf.model in [model.MoCo, model.MoCoV2]:
|
|
|
+ self.model.encoder_q = self.compressed_model.encoder_q
|
|
|
+ # self.model.encoder_k = copy.deepcopy(self._local_model.encoder_k)
|
|
|
+ elif self.conf.model == model.SimCLR:
|
|
|
+ self.model.online_encoder = self.compressed_model.online_encoder
|
|
|
+ elif self.conf.model in [model.SimSiam, model.SimSiamNoSG]:
|
|
|
+ if self._local_model is None:
|
|
|
+ self.model.online_encoder = self.compressed_model.online_encoder
|
|
|
+ self.model.online_predictor = self.compressed_model.online_predictor
|
|
|
+ return
|
|
|
+
|
|
|
+ if self.conf.update_encoder == ONLINE:
|
|
|
+ online_encoder = self.compressed_model.online_encoder
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
|
|
|
+ f"update {self.conf.update_encoder} is not supported")
|
|
|
+
|
|
|
+ if self.conf.update_predictor == GLOBAL:
|
|
|
+ predictor = self.compressed_model.online_predictor
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")
|
|
|
+
|
|
|
+ self.model.online_encoder = copy.deepcopy(online_encoder)
|
|
|
+ self.model.online_predictor = copy.deepcopy(predictor)
|
|
|
+
|
|
|
+ elif self.conf.model in [model.Symmetric, model.SymmetricNoSG]:
|
|
|
+ self.model.online_encoder = self.compressed_model.online_encoder
|
|
|
+
|
|
|
+ elif self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
|
|
|
+
|
|
|
+ if self._local_model is None:
|
|
|
+ logger.info("Use aggregated encoder and predictor")
|
|
|
+ self.model.online_encoder = self.compressed_model.online_encoder
|
|
|
+ self.model.target_encoder = self.compressed_model.online_encoder
|
|
|
+ self.model.online_predictor = self.compressed_model.online_predictor
|
|
|
+ return
|
|
|
+
|
|
|
+ def ema_online():
|
|
|
+ self._calculate_weight_scaler()
|
|
|
+ logger.info(f"Encoder: update online with EMA of global encoder @ round {self.conf.round_id}")
|
|
|
+ weight = self.encoder_distance
|
|
|
+ weight = min(1, self.weight_scaler * weight)
|
|
|
+ weight = 1 - weight
|
|
|
+ self.compressed_model = self.compressed_model.cpu()
|
|
|
+ online_encoder = self.compressed_model.online_encoder
|
|
|
+ target_encoder = self._local_model.target_encoder
|
|
|
+ ema_updater = model.EMA(weight)
|
|
|
+ model.update_moving_average(ema_updater, online_encoder, self._local_model.online_encoder)
|
|
|
+ return online_encoder, target_encoder
|
|
|
+
|
|
|
+ def ema_predictor():
|
|
|
+ logger.info(f"Predictor: use dynamic DAPU")
|
|
|
+ distance = self.encoder_distance
|
|
|
+ distance = min(1, distance * self.weight_scaler)
|
|
|
+ if distance > 0.5:
|
|
|
+ weight = distance
|
|
|
+ ema_updater = model.EMA(weight)
|
|
|
+ predictor = self._local_model.online_predictor
|
|
|
+ model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
|
|
|
+ else:
|
|
|
+ weight = 1 - distance
|
|
|
+ ema_updater = model.EMA(weight)
|
|
|
+ predictor = self.compressed_model.online_predictor
|
|
|
+ model.update_moving_average(ema_updater, predictor, self._local_model.online_predictor)
|
|
|
+ return predictor
|
|
|
+
|
|
|
+ if self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == ONLINE:
|
|
|
+ logger.info("Encoder: aggregate online, update online")
|
|
|
+ online_encoder = self.compressed_model.online_encoder
|
|
|
+ target_encoder = self._local_model.target_encoder
|
|
|
+ elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == ONLINE:
|
|
|
+ logger.info("Encoder: aggregate target, update online")
|
|
|
+ online_encoder = self.compressed_model.target_encoder
|
|
|
+ target_encoder = self._local_model.target_encoder
|
|
|
+ elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == TARGET:
|
|
|
+ logger.info("Encoder: aggregate target, update target")
|
|
|
+ online_encoder = self._local_model.online_encoder
|
|
|
+ target_encoder = self.compressed_model.target_encoder
|
|
|
+ elif self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == TARGET:
|
|
|
+ logger.info("Encoder: aggregate online, update target")
|
|
|
+ online_encoder = self._local_model.online_encoder
|
|
|
+ target_encoder = self.compressed_model.online_encoder
|
|
|
+ elif self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == BOTH:
|
|
|
+ logger.info("Encoder: aggregate online, update both")
|
|
|
+ online_encoder = self.compressed_model.online_encoder
|
|
|
+ target_encoder = self.compressed_model.online_encoder
|
|
|
+ elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == BOTH:
|
|
|
+ logger.info("Encoder: aggregate target, update both")
|
|
|
+ online_encoder = self.compressed_model.target_encoder
|
|
|
+ target_encoder = self.compressed_model.target_encoder
|
|
|
+ elif self.conf.update_encoder == NONE:
|
|
|
+ logger.info("Encoder: use local online and target encoders")
|
|
|
+ online_encoder = self._local_model.online_encoder
|
|
|
+ target_encoder = self._local_model.target_encoder
|
|
|
+ elif self.conf.update_encoder == EMA:
|
|
|
+ logger.info(f"Encoder: use EMA, weight {self.conf.encoder_weight}")
|
|
|
+ online_encoder = self._local_model.online_encoder
|
|
|
+ ema_updater = model.EMA(self.conf.encoder_weight)
|
|
|
+ model.update_moving_average(ema_updater, online_encoder, self.compressed_model.online_encoder)
|
|
|
+ target_encoder = self._local_model.target_encoder
|
|
|
+ elif self.conf.update_encoder == DYNAMIC_EMA_ONLINE:
|
|
|
+ # Use FedEMA to update online encoder
|
|
|
+ online_encoder, target_encoder = ema_online()
|
|
|
+ elif self.conf.update_encoder == SELECTIVE_EMA:
|
|
|
+ # Use FedEMA to update online encoder
|
|
|
+ # For random selection, only update with EMA when the client is selected in previous round.
|
|
|
+ if self.previous_trained_round + 1 == self.conf.round_id:
|
|
|
+ online_encoder, target_encoder = ema_online()
|
|
|
+ else:
|
|
|
+ logger.info(f"Encoder: update online and target @ round {self.conf.round_id}")
|
|
|
+ online_encoder = self.compressed_model.online_encoder
|
|
|
+ target_encoder = self.compressed_model.online_encoder
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
|
|
|
+ f"update {self.conf.update_encoder} is not supported")
|
|
|
+
|
|
|
+ if self.conf.update_predictor == GLOBAL:
|
|
|
+ logger.info("Predictor: use global predictor")
|
|
|
+ predictor = self.compressed_model.online_predictor
|
|
|
+ elif self.conf.update_predictor == LOCAL:
|
|
|
+ logger.info("Predictor: use local predictor")
|
|
|
+ predictor = self._local_model.online_predictor
|
|
|
+ elif self.conf.update_predictor == DAPU:
|
|
|
+ # Divergence-aware predictor update (DAPU)
|
|
|
+ logger.info(f"Predictor: use DAPU, mu {self.conf.dapu_threshold}")
|
|
|
+ if self.DAPU_predictor == GLOBAL:
|
|
|
+ predictor = self.compressed_model.online_predictor
|
|
|
+ elif self.DAPU_predictor == LOCAL:
|
|
|
+ predictor = self._local_model.online_predictor
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Predictor: DAPU predictor can either use local or global predictor")
|
|
|
+ elif self.conf.update_predictor == DYNAMIC_DAPU:
|
|
|
+ # Use FedEMA to update predictor
|
|
|
+ predictor = ema_predictor()
|
|
|
+ elif self.conf.update_predictor == SELECTIVE_EMA:
|
|
|
+ # For random selection, only update with EMA when the client is selected in previous round.
|
|
|
+ if self.previous_trained_round + 1 == self.conf.round_id:
|
|
|
+ predictor = ema_predictor()
|
|
|
+ else:
|
|
|
+ logger.info("Predictor: use global predictor")
|
|
|
+ predictor = self.compressed_model.online_predictor
|
|
|
+ elif self.conf.update_predictor == EMA:
|
|
|
+ logger.info(f"Predictor: use EMA, weight {self.conf.predictor_weight}")
|
|
|
+ predictor = self._local_model.online_predictor
|
|
|
+ ema_updater = model.EMA(self.conf.predictor_weight)
|
|
|
+ model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")
|
|
|
+
|
|
|
+ self.model.online_encoder = copy.deepcopy(online_encoder)
|
|
|
+ self.model.target_encoder = copy.deepcopy(target_encoder)
|
|
|
+ self.model.online_predictor = copy.deepcopy(predictor)
|
|
|
+
|
|
|
+ def train(self, conf, device=CPU):
|
|
|
+ start_time = time.time()
|
|
|
+ loss_fn, optimizer = self.pretrain_setup(conf, device)
|
|
|
+ if conf.model in [model.MoCo, model.MoCoV2]:
|
|
|
+ self.model.reset_key_encoder()
|
|
|
+ self.train_loss = []
|
|
|
+ self.model.to(device)
|
|
|
+ old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
|
+ for i in range(conf.local_epoch):
|
|
|
+ batch_loss = []
|
|
|
+ for (batched_x1, batched_x2), _ in self.train_loader:
|
|
|
+ x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
|
+ optimizer.zero_grad()
|
|
|
+
|
|
|
+ if conf.model in [model.MoCo, model.MoCoV2]:
|
|
|
+ loss = self.model(x1, x2, device)
|
|
|
+ elif conf.model == model.SimCLR:
|
|
|
+ images = torch.cat((x1, x2), dim=0)
|
|
|
+ features = self.model(images)
|
|
|
+ logits, labels = self.info_nce_loss(features)
|
|
|
+ loss = loss_fn(logits, labels)
|
|
|
+ else:
|
|
|
+ loss = self.model(x1, x2)
|
|
|
+
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+ batch_loss.append(loss.item())
|
|
|
+
|
|
|
+ if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
|
|
|
+ self.model.update_moving_average()
|
|
|
+
|
|
|
+ current_epoch_loss = sum(batch_loss) / len(batch_loss)
|
|
|
+ self.train_loss.append(float(current_epoch_loss))
|
|
|
+ self.train_time = time.time() - start_time
|
|
|
+
|
|
|
+ # store trained model locally
|
|
|
+ self._local_model = copy.deepcopy(self.model).cpu()
|
|
|
+ self.previous_trained_round = conf.round_id
|
|
|
+ if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
|
|
|
+ new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
|
+ self.encoder_distance = self._calculate_divergence(old_model, new_model)
|
|
|
+ self.encoder_distances.append(self.encoder_distance.item())
|
|
|
+ self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
|
|
|
+ if self.conf.auto_scaler == 'y' and self.conf.random_selection:
|
|
|
+ self._calculate_weight_scaler()
|
|
|
+ if (conf.round_id + 1) % 100 == 0:
|
|
|
+ logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
|
|
|
+
|
|
|
+ def _DAPU_predictor_usage(self, distance):
|
|
|
+ if distance < self.conf.dapu_threshold:
|
|
|
+ return GLOBAL
|
|
|
+ else:
|
|
|
+ return LOCAL
|
|
|
+
|
|
|
+ def _calculate_divergence(self, old_model, new_model, typ=L2):
|
|
|
+ size = 0
|
|
|
+ total_distance = 0
|
|
|
+ old_dict = old_model.state_dict()
|
|
|
+ new_dict = new_model.state_dict()
|
|
|
+ for name, param in old_model.named_parameters():
|
|
|
+ if 'conv' in name and 'weight' in name:
|
|
|
+ total_distance += self._calculate_distance(old_dict[name].detach().clone().view(1, -1),
|
|
|
+ new_dict[name].detach().clone().view(1, -1),
|
|
|
+ typ)
|
|
|
+ size += 1
|
|
|
+ distance = total_distance / size
|
|
|
+ logger.info(f"Model distance: {distance} = {total_distance}/{size}")
|
|
|
+ return distance
|
|
|
+
|
|
|
+ def _calculate_distance(self, m1, m2, typ=L2):
|
|
|
+ if typ == L2:
|
|
|
+ return torch.dist(m1, m2, 2)
|
|
|
+
|
|
|
+ def _calculate_weight_scaler(self):
|
|
|
+ if not self.weight_scaler:
|
|
|
+ if self.conf.auto_scaler == 'y':
|
|
|
+ self.weight_scaler = self.conf.auto_scaler_target / self.encoder_distance
|
|
|
+ else:
|
|
|
+ self.weight_scaler = self.conf.weight_scaler
|
|
|
+ logger.info(f"Client {self.cid}: weight scaler {self.weight_scaler}")
|
|
|
+
|
|
|
+ def load_loader(self, conf):
|
|
|
+ drop_last = conf.drop_last
|
|
|
+ train_loader = self.train_data.loader(conf.batch_size,
|
|
|
+ self.cid,
|
|
|
+ shuffle=True,
|
|
|
+ drop_last=drop_last,
|
|
|
+ seed=conf.seed,
|
|
|
+ transform=self._load_transform(conf))
|
|
|
+ _print_label_count(self.cid, self.train_data.data[self.cid]['y'])
|
|
|
+ return train_loader
|
|
|
+
|
|
|
+ def load_optimizer(self, conf):
|
|
|
+ lr = conf.optimizer.lr
|
|
|
+ if conf.optimizer.lr_type == "cosine":
|
|
|
+ lr = compute_lr(conf.round_id, conf.rounds, 0, conf.optimizer.lr)
|
|
|
+
|
|
|
+ # movo_v1 should use the default learning rate
|
|
|
+ if conf.model == model.MoCo:
|
|
|
+ lr = conf.optimizer.lr
|
|
|
+
|
|
|
+ params = self.model.parameters()
|
|
|
+ if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
|
|
|
+ params = [
|
|
|
+ {'params': self.model.online_encoder.parameters()},
|
|
|
+ {'params': self.model.online_predictor.parameters()}
|
|
|
+ ]
|
|
|
+
|
|
|
+ if conf.optimizer.type == "Adam":
|
|
|
+ optimizer = torch.optim.Adam(params, lr=lr)
|
|
|
+ else:
|
|
|
+ optimizer = torch.optim.SGD(params,
|
|
|
+ lr=lr,
|
|
|
+ momentum=conf.optimizer.momentum,
|
|
|
+ weight_decay=conf.optimizer.weight_decay)
|
|
|
+ return optimizer
|
|
|
+
|
|
|
+ def _load_transform(self, conf):
|
|
|
+ transformation = utils.get_transformation(conf.model)
|
|
|
+ return transformation(conf.image_size, conf.gaussian)
|
|
|
+
|
|
|
+ def post_upload(self):
|
|
|
+ if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
|
|
|
+ del self.model
|
|
|
+ del self.compressed_model
|
|
|
+ self.model = None
|
|
|
+ self.compressed_model = None
|
|
|
+ assert self.model is None
|
|
|
+ assert self.compressed_model is None
|
|
|
+ gc.collect()
|
|
|
+ torch.cuda.empty_cache()
|
|
|
+
|
|
|
+ def info_nce_loss(self, features, n_views=2, temperature=0.07):
|
|
|
+ labels = torch.cat([torch.arange(self.conf.batch_size) for i in range(n_views)], dim=0)
|
|
|
+ labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
|
|
|
+ labels = labels.to(self.device)
|
|
|
+
|
|
|
+ features = F.normalize(features, dim=1)
|
|
|
+
|
|
|
+ similarity_matrix = torch.matmul(features, features.T)
|
|
|
+ # assert similarity_matrix.shape == (
|
|
|
+ # n_views * self.conf.batch_size, n_views * self.conf.batch_size)
|
|
|
+ # assert similarity_matrix.shape == labels.shape
|
|
|
+
|
|
|
+ # discard the main diagonal from both: labels and similarities matrix
|
|
|
+ mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
|
|
|
+ labels = labels[~mask].view(labels.shape[0], -1)
|
|
|
+ similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
|
|
|
+ # assert similarity_matrix.shape == labels.shape
|
|
|
+
|
|
|
+ # select and combine multiple positives
|
|
|
+ positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
|
|
|
+
|
|
|
+ # select only the negatives the negatives
|
|
|
+ negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
|
|
|
+
|
|
|
+ logits = torch.cat([positives, negatives], dim=1)
|
|
|
+ labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
|
|
|
+
|
|
|
+ logits = logits / temperature
|
|
|
+ return logits, labels
|
|
|
+
|
|
|
+
|
|
|
+def compute_lr(current_round, rounds=800, eta_min=0, eta_max=0.3):
|
|
|
+ """Compute learning rate as cosine decay"""
|
|
|
+ pi = np.pi
|
|
|
+ eta_t = eta_min + 0.5 * (eta_max - eta_min) * (np.cos(pi * current_round / rounds) + 1)
|
|
|
+ return eta_t
|
|
|
+
|
|
|
+
|
|
|
+def _print_label_count(cid, labels):
|
|
|
+ logger.info(f"client {cid}: {Counter(labels)}")
|