import copy
import gc
import logging
import time
from collections import Counter

import numpy as np
import torch
import torch._utils
import torch.nn as nn
import torch.nn.functional as F

import model
import utils
from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
from easyfl.client.base import BaseClient
from easyfl.distributed.distributed import CPU
from client import FedSSLClient

logger = logging.getLogger(__name__)

L2 = "l2"


def model_dot_product(w1, w2, requires_grad=True):
    """Return the inner product of two models' parameters.

    Parameters of ``w1`` and ``w2`` are paired positionally (``zip`` over
    ``parameters()``), multiplied element-wise and summed over all pairs,
    i.e. ``sum_k sum(p1_k * p2_k)``. Pairing stops at the shorter list.

    Args:
        w1: First ``nn.Module``.
        w2: Second ``nn.Module``; assumed to share ``w1``'s architecture so
            paired parameters have matching shapes -- TODO confirm at callers.
        requires_grad: If True the sum is built from the parameters
            themselves (stays in the autograd graph); if False it uses
            ``.data`` and is detached.

    Returns:
        A scalar tensor, or the float ``0.0`` when there are no parameters.
    """
    dot_product = 0.0
    for p1, p2 in zip(w1.parameters(), w2.parameters()):
        if requires_grad:
            dot_product += torch.sum(p1 * p2)
        else:
            dot_product += torch.sum(p1.data * p2.data)
    return dot_product


class FedSSLWithPgFedClient(FedSSLClient):
    """FedSSL client carrying the extra per-client state used by PGFed.

    ``prev_mean_grad``, ``prev_convex_comb_grad``, ``prev_loss_minuses`` and
    ``a_i`` are expected to be filled in from outside this class (presumably
    by the server between rounds -- not visible here); ``train`` produces
    ``loss_minus`` and ``latest_grad``.
    """

    # Temporary per-pass sample cap carried over from debugging (marked
    # "delete later" in the original). Each pass over ``train_loader`` stops
    # once at least this many samples have been consumed. ``None`` disables it.
    _debug_sample_cap = 50

    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
        super().__init__(cid, conf, train_data, test_data, device, sleep_time)
        self._local_model = None
        self.DAPU_predictor = LOCAL
        self.encoder_distance = 1
        self.encoder_distances = []
        self.previous_trained_round = -1
        self.weight_scaler = None
        # Model-shaped container whose parameters hold the averaged gradient
        # from the last ``train`` call; created lazily on first use.
        self.latest_grad = None
        self.lambdaa = 1.0  # PGFed learning rate for a_i, Regularization weight for pFedMe
        # Maps client id -> value consumed by ``update_a_i``.
        self.prev_loss_minuses = {}
        self.prev_mean_grad = None
        self.prev_convex_comb_grad = None
        # Per-client mixing weights indexed by the keys of prev_loss_minuses;
        # None until set externally -- NOTE(review): must be set before
        # prev_loss_minuses is non-empty.
        self.a_i = None

    def _pgfed_batches(self, device):
        """Yield ``(x1, x2)`` view pairs from ``train_loader`` moved to ``device``.

        Stops once ``_debug_sample_cap`` samples have been yielded (the check
        runs before each batch, so at least one batch is always yielded when
        the loader is non-empty).
        """
        cap = self._debug_sample_cap
        seen = 0
        for (batched_x1, batched_x2), _ in self.train_loader:
            if cap is not None and seen >= cap:
                break
            x1, x2 = batched_x1.to(device), batched_x2.to(device)
            seen += x1.size(0)
            yield x1, x2

    def _pgfed_ssl_loss(self, conf, loss_fn, x1, x2, device):
        """Forward one pair of views through ``self.model`` and return its loss."""
        if conf.model in [model.MoCo, model.MoCoV2]:
            return self.model(x1, x2, device)
        if conf.model == model.SimCLR:
            images = torch.cat((x1, x2), dim=0)
            features = self.model(images)
            logits, labels = self.info_nce_loss(features)
            return loss_fn(logits, labels)
        return self.model(x1, x2)

    def train(self, conf, device=CPU):
        """Train locally, then refresh the PGFed statistics.

        Runs ``conf.local_epoch`` epochs of self-supervised training. When
        ``prev_convex_comb_grad`` is set, its parameters are added to the
        local gradients before each optimizer step. ``update_a_i`` is then
        called with ``<model, prev_mean_grad>``.

        A second pass over the (capped) training data, without optimizer
        steps, then sets:

        * ``self.loss_minus``: sample-weighted mean loss minus
          ``<latest_grad, model>``;
        * ``self.latest_grad``: parameters set to the gradient accumulated
          over that pass divided by the number of batches (zeros where a
          parameter received no gradient).

        Also sets ``self.train_loss`` (mean batch loss per epoch) and
        ``self.train_time``.

        Args:
            conf: Training config; reads ``model``, ``local_epoch`` and
                ``momentum_update``.
            device: Device the model and every batch are moved to.
        """
        start_time = time.time()
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        if conf.model in [model.MoCo, model.MoCoV2]:
            self.model.reset_key_encoder()
        self.train_loss = []
        self.model.to(device)

        for _ in range(conf.local_epoch):
            batch_loss = []
            for x1, x2 in self._pgfed_batches(device):
                optimizer.zero_grad()
                loss = self._pgfed_ssl_loss(conf, loss_fn, x1, x2, device)
                loss.backward()
                if self.prev_convex_comb_grad is not None:
                    # Add the convex combination of gradients to the local
                    # gradient. Parameters without a gradient (frozen/unused)
                    # are skipped instead of crashing on ``None.data``.
                    for p_m, p_prev_conv in zip(self.model.parameters(), self.prev_convex_comb_grad.parameters()):
                        if p_m.grad is not None:
                            p_m.grad.data += p_prev_conv.data
                    dot_prod = model_dot_product(self.model, self.prev_mean_grad, requires_grad=False)
                    self.update_a_i(dot_prod)
                optimizer.step()
                batch_loss.append(loss.item())
                if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
                    self.model.update_moving_average()
            if batch_loss:  # an empty loader would otherwise divide by zero
                current_epoch_loss = sum(batch_loss) / len(batch_loss)
                self.train_loss.append(float(current_epoch_loss))

        # Second pass: accumulate loss and gradient at the trained weights.
        # Batches go to ``device`` (where the model lives), not ``self.device``.
        optimizer.zero_grad()
        if self.latest_grad is None:
            self.latest_grad = copy.deepcopy(self.model)
        self.loss_minus = 0.0
        test_num = 0
        num_batches = 0
        for x1, x2 in self._pgfed_batches(device):
            loss = self._pgfed_ssl_loss(conf, loss_fn, x1, x2, device)
            self.loss_minus += loss.item() * x1.size(0)
            loss.backward()  # gradients accumulate across batches (no zero_grad)
            test_num += x1.size(0)
            num_batches += 1
        if test_num > 0:
            self.loss_minus /= test_num
        # Divide by the batches actually processed (the sample cap can stop
        # the pass early), not len(self.train_loader).
        for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
            if p.grad is not None and num_batches > 0:
                p_l.data = p.grad.data.clone() / num_batches
            else:
                p_l.data = torch.zeros_like(p_l.data)
        self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
        self.train_time = time.time() - start_time

        # Divergence / DAPU bookkeeping is currently disabled. Re-enabling it
        # requires snapshotting the encoder before the training loop:
        # old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
        # store trained model locally
        # self._local_model = copy.deepcopy(self.model).cpu()
        # self.previous_trained_round = conf.round_id
        # if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
        #     new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
        #     self.encoder_distance = self._calculate_divergence(old_model, new_model)
        #     self.encoder_distances.append(self.encoder_distance.item())
        #     self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
        #     if self.conf.auto_scaler == 'y' and self.conf.random_selection:
        #         self._calculate_weight_scaler()
        #     if (conf.round_id + 1) % 100 == 0:
        #         logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")

    def update_a_i(self, dot_prod):
        """Take one clipped step on the mixing weights ``self.a_i``.

        For every ``clt_j`` in ``prev_loss_minuses``, subtracts
        ``lambdaa * (mu_loss_minus + dot_prod)`` from ``self.a_i[clt_j]`` and
        clips the result at 0. A no-op while ``prev_loss_minuses`` is empty.
        """
        for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
            self.a_i[clt_j] -= self.lambdaa * (mu_loss_minus + dot_prod)
            self.a_i[clt_j] = max(self.a_i[clt_j], 0.0)

    def set_prev_mean_grad(self, mean_grad):
        """Store ``mean_grad`` as ``prev_mean_grad``.

        The first call deep-copies the module; later calls overwrite only its
        parameters in place via ``set_model`` (buffers are not copied).
        """
        if self.prev_mean_grad is None:
            print("initing prev_mean_grad")
            self.prev_mean_grad = copy.deepcopy(mean_grad)
        else:
            print("setting prev_mean_grad")
            self.set_model(self.prev_mean_grad, mean_grad)

    def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
        """Store ``convex_comb_grad`` as ``prev_convex_comb_grad``.

        The first call deep-copies the module, ignoring ``momentum``. Later
        calls blend parameters via ``set_model`` using ``momentum``.
        """
        if self.prev_convex_comb_grad is None:
            print("initing prev_convex_comb_grad")
            self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
        else:
            print("setting prev_convex_comb_grad")
            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)

    def set_model(self, old_m, new_m, momentum=0.0):
        """Overwrite ``old_m``'s parameters with a blend of ``new_m``'s.

        Each parameter becomes ``(1 - momentum) * new + momentum * old``;
        ``momentum=0.0`` is a plain copy. Buffers are left untouched.
        """
        for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
            # The arithmetic already allocates a fresh tensor, so no clone().
            p_old.data = (1 - momentum) * p_new.data + momentum * p_old.data