123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369 |
- import copy
- import gc
- import logging
- import time
- from collections import Counter
- import numpy as np
- import torch
- import torch._utils
- import torch.nn as nn
- import torch.nn.functional as F
- import model
- import utils
- from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
- from easyfl.client.base import BaseClient
- from easyfl.distributed.distributed import CPU
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Identifier of the L2 (Euclidean) distance type; it is the default `typ`
# accepted by the divergence/distance helpers of FedSSLClient.
L2 = "l2"
class FedSSLClient(BaseClient):
    """Client for federated self-supervised learning.

    Depending on ``conf.model`` the client trains a MoCo, SimCLR, SimSiam,
    Symmetric or BYOL style model.  After each local training the trained
    model is kept in ``self._local_model`` so that, on the next round, the
    model received from the server (``self.compressed_model``) can be merged
    with it according to ``conf.update_encoder`` / ``conf.update_predictor``.
    """

    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
        super(FedSSLClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
        # CPU copy of the model after the latest local training; None until train() has run.
        self._local_model = None
        # Which predictor DAPU uses on the next update (LOCAL or GLOBAL); refreshed in train().
        self.DAPU_predictor = LOCAL
        # Divergence measured during the latest local training (see _calculate_divergence).
        self.encoder_distance = 1
        # History of the measured divergences as Python floats.
        self.encoder_distances = []
        # Round id of the latest local training; -1 means "never trained".
        self.previous_trained_round = -1
        # Factor applied to encoder_distance by the dynamic EMA updates; computed lazily.
        self.weight_scaler = None

    def decompression(self):
        """Adopt the received model on first use, then merge it into ``self.model``."""
        if self.model is None:
            # Nothing local to merge into yet: start from the received model.
            self.model = self.compressed_model
        self.update_model()

    def update_model(self):
        """Update ``self.model`` from ``self.compressed_model`` according to ``conf.model``.

        Raises:
            ValueError: if the configured encoder/predictor update strategy is
                not supported for the configured model.
        """
        model_type = self.conf.model
        if model_type in [model.MoCo, model.MoCoV2]:
            self.model.encoder_q = self.compressed_model.encoder_q
        elif model_type == model.SimCLR:
            self.model.online_encoder = self.compressed_model.online_encoder
        elif model_type in [model.SimSiam, model.SimSiamNoSG]:
            self._update_simsiam()
        elif model_type in [model.Symmetric, model.SymmetricNoSG]:
            self.model.online_encoder = self.compressed_model.online_encoder
        elif model_type in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
            self._update_byol()

    def _update_simsiam(self):
        """Take the online encoder and predictor of the received model (SimSiam variants)."""
        if self._local_model is None:
            # No local training yet: use the received modules directly.
            self.model.online_encoder = self.compressed_model.online_encoder
            self.model.online_predictor = self.compressed_model.online_predictor
            return

        # Only "update online encoder" + "global predictor" is implemented here.
        if self.conf.update_encoder != ONLINE:
            raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
                             f"update {self.conf.update_encoder} is not supported")
        if self.conf.update_predictor != GLOBAL:
            raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")

        self.model.online_encoder = copy.deepcopy(self.compressed_model.online_encoder)
        self.model.online_predictor = copy.deepcopy(self.compressed_model.online_predictor)

    def _update_byol(self):
        """Merge the received and the locally stored model for the BYOL variants."""
        if self._local_model is None:
            logger.info("Use aggregated encoder and predictor")
            self.model.online_encoder = self.compressed_model.online_encoder
            # NOTE(review): the target encoder is bound to the very same module
            # object as the online encoder here (no deepcopy, unlike the branch
            # below) - confirm that this aliasing is intended.
            self.model.target_encoder = self.compressed_model.online_encoder
            self.model.online_predictor = self.compressed_model.online_predictor
            return

        # Order matters: the encoder selection may initialise the weight scaler
        # and move the received model to CPU before the predictor is selected.
        online_encoder, target_encoder = self._select_byol_encoders()
        predictor = self._select_byol_predictor()

        self.model.online_encoder = copy.deepcopy(online_encoder)
        self.model.target_encoder = copy.deepcopy(target_encoder)
        self.model.online_predictor = copy.deepcopy(predictor)

    def _ema_online_encoder(self):
        """Return ``(online_encoder, target_encoder)`` for the dynamic EMA encoder update.

        The EMA weight is ``1 - min(1, weight_scaler * encoder_distance)``.
        ``model.update_moving_average`` is called with the global online encoder
        as the moving-average model and the local online encoder as the current
        model; presumably it blends the latter into the former in place -
        confirm against ``model.update_moving_average``.
        """
        self._calculate_weight_scaler()
        logger.info(f"Encoder: update online with EMA of global encoder @ round {self.conf.round_id}")
        weight = 1 - min(1, self.weight_scaler * self.encoder_distance)

        # The locally stored model lives on CPU, so bring the received one there too.
        self.compressed_model = self.compressed_model.cpu()
        online_encoder = self.compressed_model.online_encoder
        target_encoder = self._local_model.target_encoder

        ema_updater = model.EMA(weight)
        model.update_moving_average(ema_updater, online_encoder, self._local_model.online_encoder)
        return online_encoder, target_encoder

    def _ema_predictor(self):
        """Return the predictor for the dynamic DAPU update.

        With ``distance = min(1, encoder_distance * weight_scaler)``: when
        ``distance > 0.5`` the local predictor is returned after an EMA call with
        weight ``distance`` against the global predictor; otherwise the global
        predictor is returned after an EMA call with weight ``1 - distance``
        against the local predictor.
        """
        # The scaler is otherwise only initialised by the encoder EMA path (or
        # in train() under auto-scaling); without this call a predictor-only
        # dynamic configuration would multiply by None below.
        self._calculate_weight_scaler()
        logger.info("Predictor: use dynamic DAPU")
        distance = min(1, self.encoder_distance * self.weight_scaler)

        local_predictor = self._local_model.online_predictor
        global_predictor = self.compressed_model.online_predictor
        if distance > 0.5:
            predictor, other, weight = local_predictor, global_predictor, distance
        else:
            predictor, other, weight = global_predictor, local_predictor, 1 - distance

        model.update_moving_average(model.EMA(weight), predictor, other)
        return predictor

    def _select_byol_encoders(self):
        """Return ``(online_encoder, target_encoder)`` chosen by the encoder strategy.

        Raises:
            ValueError: for an unsupported aggregate/update combination.
        """
        aggregate = self.conf.aggregate_encoder
        update = self.conf.update_encoder

        if aggregate == ONLINE and update == ONLINE:
            logger.info("Encoder: aggregate online, update online")
            online_encoder = self.compressed_model.online_encoder
            target_encoder = self._local_model.target_encoder
        elif aggregate == TARGET and update == ONLINE:
            logger.info("Encoder: aggregate target, update online")
            online_encoder = self.compressed_model.target_encoder
            target_encoder = self._local_model.target_encoder
        elif aggregate == TARGET and update == TARGET:
            logger.info("Encoder: aggregate target, update target")
            online_encoder = self._local_model.online_encoder
            target_encoder = self.compressed_model.target_encoder
        elif aggregate == ONLINE and update == TARGET:
            logger.info("Encoder: aggregate online, update target")
            online_encoder = self._local_model.online_encoder
            target_encoder = self.compressed_model.online_encoder
        elif aggregate == ONLINE and update == BOTH:
            logger.info("Encoder: aggregate online, update both")
            online_encoder = self.compressed_model.online_encoder
            target_encoder = self.compressed_model.online_encoder
        elif aggregate == TARGET and update == BOTH:
            logger.info("Encoder: aggregate target, update both")
            online_encoder = self.compressed_model.target_encoder
            target_encoder = self.compressed_model.target_encoder
        elif update == NONE:
            logger.info("Encoder: use local online and target encoders")
            online_encoder = self._local_model.online_encoder
            target_encoder = self._local_model.target_encoder
        elif update == EMA:
            logger.info(f"Encoder: use EMA, weight {self.conf.encoder_weight}")
            online_encoder = self._local_model.online_encoder
            ema_updater = model.EMA(self.conf.encoder_weight)
            model.update_moving_average(ema_updater, online_encoder, self.compressed_model.online_encoder)
            target_encoder = self._local_model.target_encoder
        elif update == DYNAMIC_EMA_ONLINE:
            online_encoder, target_encoder = self._ema_online_encoder()
        elif update == SELECTIVE_EMA:
            # EMA only when this client also trained in the immediately preceding
            # round; otherwise take the global encoder for both networks.
            if self.previous_trained_round + 1 == self.conf.round_id:
                online_encoder, target_encoder = self._ema_online_encoder()
            else:
                logger.info(f"Encoder: update online and target @ round {self.conf.round_id}")
                online_encoder = self.compressed_model.online_encoder
                target_encoder = self.compressed_model.online_encoder
        else:
            raise ValueError(f"Encoder: aggregate {aggregate}, "
                             f"update {update} is not supported")
        return online_encoder, target_encoder

    def _select_byol_predictor(self):
        """Return the predictor chosen by ``conf.update_predictor``.

        Raises:
            ValueError: for an unsupported predictor strategy or DAPU state.
        """
        update = self.conf.update_predictor

        if update == GLOBAL:
            logger.info("Predictor: use global predictor")
            predictor = self.compressed_model.online_predictor
        elif update == LOCAL:
            logger.info("Predictor: use local predictor")
            predictor = self._local_model.online_predictor
        elif update == DAPU:
            # self.DAPU_predictor was decided at the end of the last train()
            # by comparing the measured divergence with conf.dapu_threshold.
            logger.info(f"Predictor: use DAPU, mu {self.conf.dapu_threshold}")
            if self.DAPU_predictor == GLOBAL:
                predictor = self.compressed_model.online_predictor
            elif self.DAPU_predictor == LOCAL:
                predictor = self._local_model.online_predictor
            else:
                raise ValueError("Predictor: DAPU predictor can either use local or global predictor")
        elif update == DYNAMIC_DAPU:
            predictor = self._ema_predictor()
        elif update == SELECTIVE_EMA:
            if self.previous_trained_round + 1 == self.conf.round_id:
                predictor = self._ema_predictor()
            else:
                logger.info("Predictor: use global predictor")
                predictor = self.compressed_model.online_predictor
        elif update == EMA:
            logger.info(f"Predictor: use EMA, weight {self.conf.predictor_weight}")
            predictor = self._local_model.online_predictor
            ema_updater = model.EMA(self.conf.predictor_weight)
            model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
        else:
            raise ValueError(f"Predictor: {update} is not supported")
        return predictor

    def train(self, conf, device=CPU):
        """Run ``conf.local_epoch`` epochs of self-supervised training.

        Sets ``self.train_loss`` (mean loss per epoch), ``self.train_time``,
        ``self._local_model`` and ``self.previous_trained_round``.  When a
        divergence-aware strategy is configured it also updates
        ``self.encoder_distance``, ``self.encoder_distances`` and
        ``self.DAPU_predictor``.

        Args:
            conf: training configuration.
            device: device the model and the batches are moved to.
        """
        start_time = time.time()
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        if conf.model in [model.MoCo, model.MoCoV2]:
            self.model.reset_key_encoder()
        self.train_loss = []
        self.model.to(device)

        # Snapshot of every child module except the last one, taken before
        # training (presumably this drops the predictor - confirm child order).
        old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()

        is_byol = conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]
        for _ in range(conf.local_epoch):
            batch_loss = []
            for (batched_x1, batched_x2), _ in self.train_loader:
                x1, x2 = batched_x1.to(device), batched_x2.to(device)
                optimizer.zero_grad()

                if conf.model in [model.MoCo, model.MoCoV2]:
                    loss = self.model(x1, x2, device)
                elif conf.model == model.SimCLR:
                    images = torch.cat((x1, x2), dim=0)
                    features = self.model(images)
                    logits, labels = self.info_nce_loss(features)
                    loss = loss_fn(logits, labels)
                else:
                    loss = self.model(x1, x2)

                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())

                # The model's own moving-average update runs after every optimizer step.
                if is_byol and conf.momentum_update:
                    self.model.update_moving_average()

            current_epoch_loss = sum(batch_loss) / len(batch_loss)
            self.train_loss.append(float(current_epoch_loss))

        self.train_time = time.time() - start_time

        # Keep the trained model locally so the next update_model() can merge with it.
        self._local_model = copy.deepcopy(self.model).cpu()
        self.previous_trained_round = conf.round_id

        needs_divergence = (conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA]
                            or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA])
        if needs_divergence:
            new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
            self.encoder_distance = self._calculate_divergence(old_model, new_model)
            # float() accepts both the tensor result and the plain 0 returned
            # when no convolution weight was found (an int has no .item()).
            self.encoder_distances.append(float(self.encoder_distance))
            self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
            if self.conf.auto_scaler == 'y' and self.conf.random_selection:
                self._calculate_weight_scaler()
            if (conf.round_id + 1) % 100 == 0:
                logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")

    def _DAPU_predictor_usage(self, distance):
        """Return GLOBAL when ``distance`` is below ``conf.dapu_threshold``, else LOCAL."""
        return GLOBAL if distance < self.conf.dapu_threshold else LOCAL

    def _calculate_divergence(self, old_model, new_model, typ=L2):
        """Return the mean distance between matching conv weights of two models.

        Only parameters whose name contains both ``'conv'`` and ``'weight'`` are
        compared.

        Args:
            old_model: model before training.
            new_model: model after training; must expose the same parameter names.
            typ: distance type forwarded to ``_calculate_distance``.

        Returns:
            The summed distance divided by the number of compared parameters,
            or ``0`` when no parameter matched.
        """
        old_dict = old_model.state_dict()
        new_dict = new_model.state_dict()
        conv_weight_names = [name for name, _ in old_model.named_parameters()
                             if 'conv' in name and 'weight' in name]
        if not conv_weight_names:
            return 0

        total_distance = 0
        for name in conv_weight_names:
            # reshape (rather than view) also handles non-contiguous weights.
            total_distance += self._calculate_distance(old_dict[name].detach().reshape(1, -1),
                                                       new_dict[name].detach().reshape(1, -1),
                                                       typ)
        size = len(conv_weight_names)
        distance = total_distance / size
        logger.info(f"Model distance: {distance} = {total_distance}/{size}")
        return distance

    def _calculate_distance(self, m1, m2, typ=L2):
        """Return the distance between tensors ``m1`` and ``m2``.

        Args:
            m1: first tensor.
            m2: second tensor.
            typ: distance type; only ``L2`` (2-norm of the difference) is supported.

        Raises:
            ValueError: if ``typ`` is not a supported distance type.
        """
        if typ == L2:
            return torch.dist(m1, m2, 2)
        # Fail loudly instead of returning None, which would only surface later
        # as an opaque TypeError when the result is accumulated.
        raise ValueError(f"Distance type {typ} is not supported")

    def _calculate_weight_scaler(self):
        """Initialise ``self.weight_scaler`` once; later calls are no-ops.

        With ``conf.auto_scaler == 'y'`` the scaler is
        ``conf.auto_scaler_target / self.encoder_distance``; otherwise it is
        ``conf.weight_scaler``.
        """
        if self.weight_scaler is not None:
            return
        if self.conf.auto_scaler == 'y':
            self.weight_scaler = self.conf.auto_scaler_target / self.encoder_distance
        else:
            self.weight_scaler = self.conf.weight_scaler
        logger.info(f"Client {self.cid}: weight scaler {self.weight_scaler}")

    def load_loader(self, conf):
        """Build this client's shuffled training loader and log its label counts."""
        train_loader = self.train_data.loader(conf.batch_size,
                                              self.cid,
                                              shuffle=True,
                                              drop_last=conf.drop_last,
                                              seed=conf.seed,
                                              transform=self._load_transform(conf))
        _print_label_count(self.cid, self.train_data.data[self.cid]['y'])
        return train_loader

    def load_optimizer(self, conf):
        """Create the optimizer for the current round.

        The learning rate is ``conf.optimizer.lr``, or its cosine-decayed value
        for the current round when ``conf.optimizer.lr_type == "cosine"``;
        ``model.MoCo`` always keeps the base rate.  For the BYOL variants only
        the online encoder and the online predictor are optimized.

        Returns:
            ``torch.optim.Adam`` when ``conf.optimizer.type == "Adam"``,
            otherwise ``torch.optim.SGD`` with the configured momentum and
            weight decay.
        """
        lr = conf.optimizer.lr
        if conf.optimizer.lr_type == "cosine" and conf.model != model.MoCo:
            lr = compute_lr(conf.round_id, conf.rounds, 0, conf.optimizer.lr)

        if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
            params = [
                {'params': self.model.online_encoder.parameters()},
                {'params': self.model.online_predictor.parameters()},
            ]
        else:
            params = self.model.parameters()

        if conf.optimizer.type == "Adam":
            return torch.optim.Adam(params, lr=lr)
        return torch.optim.SGD(params,
                               lr=lr,
                               momentum=conf.optimizer.momentum,
                               weight_decay=conf.optimizer.weight_decay)

    def _load_transform(self, conf):
        """Return the transformation for ``conf.model`` built with ``conf.image_size`` and ``conf.gaussian``."""
        transformation = utils.get_transformation(conf.model)
        return transformation(conf.image_size, conf.gaussian)

    def post_upload(self):
        """Drop the model references after upload for the BYOL variants and free memory."""
        if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
            # Rebinding to None releases the references; self._local_model keeps
            # the trained copy needed for the next update.
            self.model = None
            self.compressed_model = None
            gc.collect()
            torch.cuda.empty_cache()

    def info_nce_loss(self, features, n_views=2, temperature=0.07):
        """Build InfoNCE logits and labels from the features of ``n_views`` stacked views.

        Args:
            features: tensor whose first dimension is ``n_views`` equally sized
                groups of samples stacked view after view (as produced by
                ``torch.cat((x1, x2), dim=0)`` in ``train``).
            n_views: number of views per sample.
            temperature: value the logits are divided by.

        Returns:
            ``(logits, labels)``: in every row of ``logits`` the positive
            similarities come first, followed by the negatives; ``labels`` is
            all zeros, i.e. it points at column 0 of every row.
        """
        device = features.device
        # Derive the batch size from the features themselves so that a smaller
        # final batch does not break the reshapes below.
        batch_size = features.shape[0] // n_views
        sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
        labels = (sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)).float()

        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # Remove the diagonal (similarity of every sample with itself) from
        # both the labels and the similarity matrix.
        mask = torch.eye(labels.shape[0], dtype=torch.bool, device=device)
        labels = labels[~mask].view(labels.shape[0], -1)
        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

        logits = torch.cat([positives, negatives], dim=1)
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
        logits = logits / temperature
        return logits, labels
def compute_lr(current_round, rounds=800, eta_min=0, eta_max=0.3):
    """Return the cosine-decayed learning rate for ``current_round``.

    The rate equals ``eta_max`` at round 0 and follows half a cosine period
    down to ``eta_min`` at round ``rounds``.
    """
    half_range = 0.5 * (eta_max - eta_min)
    cosine_factor = np.cos(np.pi * current_round / rounds) + 1
    return eta_min + half_range * cosine_factor
def _print_label_count(cid, labels):
    """Log how often each label occurs in ``labels`` for client ``cid``."""
    label_histogram = Counter(labels)
    logger.info(f"client {cid}: {label_histogram}")
|