123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- import copy
- import gc
- import logging
- import time
- from collections import Counter
- import numpy as np
- import torch
- import torch._utils
- import torch.nn as nn
- import torch.nn.functional as F
- import model
- import utils
- from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
- from easyfl.client.base import BaseClient
- from easyfl.distributed.distributed import CPU
- from client import FedSSLClient
- logger = logging.getLogger(__name__)
- L2 = "l2"
def model_dot_product(w1, w2, requires_grad=True):
    """Return the dot product of two models' parameters.

    Parameters of ``w1`` and ``w2`` are paired in ``parameters()`` order,
    multiplied element-wise, and the per-pair sums are added together.

    Args:
        w1: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        w2: Object exposing ``parameters()``; paired with ``w1`` via ``zip``,
            so any surplus parameters on the longer side are ignored.
        requires_grad: When true the parameters themselves are multiplied;
            when false their ``.data`` tensors are used instead.

    Returns:
        The accumulated sum as a tensor, or the float ``0.0`` when there are
        no parameter pairs to iterate over.
    """
    pairs = zip(w1.parameters(), w2.parameters())
    if not requires_grad:
        # Swap each parameter for its raw ``.data`` tensor before multiplying.
        pairs = ((left.data, right.data) for left, right in pairs)

    total = 0.0
    for left, right in pairs:
        total += torch.sum(left * right)
    return total
class FedSSLWithPgFedClient(FedSSLClient):
    """FedSSL client whose local training adds PGFed-style gradient terms.

    On top of the parent client's self-supervised training step, every batch
    gradient is shifted by ``prev_convex_comb_grad`` (when one has been set)
    and the per-client coefficients ``a_i`` are updated from the dot product
    between the current model and ``prev_mean_grad``.

    ``a_i``, ``prev_loss_minuses``, ``prev_mean_grad`` and
    ``prev_convex_comb_grad`` are only read or refreshed here; they are
    presumably supplied by the server side between rounds (confirm against
    the caller).
    """

    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
        """Initialise the parent client and the PGFed bookkeeping fields."""
        super().__init__(cid, conf, train_data, test_data, device, sleep_time)
        self._local_model = None
        self.DAPU_predictor = LOCAL
        self.encoder_distance = 1
        self.encoder_distances = []
        self.previous_trained_round = -1
        self.weight_scaler = None
        # Model-shaped container created lazily at the end of train().
        self.latest_grad = None
        self.lambdaa = 1.0  # PGFed learning rate for a_i, Regularization weight for pFedMe
        # Maps a client id to the value added to dot_prod in update_a_i().
        self.prev_loss_minuses = {}
        self.prev_mean_grad = None
        self.prev_convex_comb_grad = None
        # Not initialised here: must be assigned before update_a_i() runs
        # with a non-empty prev_loss_minuses, otherwise indexing None fails.
        self.a_i = None

    def _pgfed_batch_loss(self, conf, loss_fn, x1, x2, device):
        """Run the forward pass for one batch and return its loss.

        Args:
            conf: Configuration whose ``model`` field selects the branch.
            loss_fn: Criterion applied to the SimCLR logits/labels only.
            x1: First input batch, already on ``device``.
            x2: Second input batch, already on ``device``.
            device: Device forwarded to MoCo-style models.

        Returns:
            The loss produced by the branch matching ``conf.model``.
        """
        if conf.model in [model.MoCo, model.MoCoV2]:
            return self.model(x1, x2, device)
        if conf.model == model.SimCLR:
            images = torch.cat((x1, x2), dim=0)
            features = self.model(images)
            logits, labels = self.info_nce_loss(features)
            return loss_fn(logits, labels)
        return self.model(x1, x2)

    def train(self, conf, device=CPU):
        """Train locally for ``conf.local_epoch`` epochs with PGFed terms.

        Side effects: fills ``self.train_loss`` with one mean loss per epoch,
        sets ``self.loss_minus``, lazily creates ``self.latest_grad``, may
        update ``self.a_i`` via :meth:`update_a_i`, and records
        ``self.train_time``.

        Args:
            conf: Run configuration (``model``, ``local_epoch``,
                ``momentum_update`` are read here).
            device: Device the model and every batch are moved to.

        Raises:
            ZeroDivisionError: If ``self.train_loader`` yields no batches.
        """
        start_time = time.time()
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        if conf.model in [model.MoCo, model.MoCoV2]:
            self.model.reset_key_encoder()
        self.train_loss = []
        self.model.to(device)

        uses_moving_average = (
            conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]
            and conf.momentum_update
        )

        for _ in range(conf.local_epoch):
            batch_loss = []
            for (batched_x1, batched_x2), _ in self.train_loader:
                x1, x2 = batched_x1.to(device), batched_x2.to(device)
                optimizer.zero_grad()
                loss = self._pgfed_batch_loss(conf, loss_fn, x1, x2, device)
                loss.backward()
                if self.prev_convex_comb_grad is not None:
                    for p_m, p_prev_conv in zip(self.model.parameters(),
                                                self.prev_convex_comb_grad.parameters()):
                        # Parameters that took no part in backward() have no
                        # gradient to correct; skipping keeps the zip aligned.
                        if p_m.grad is None:
                            continue
                        p_m.grad.data += p_prev_conv.data
                    dot_prod = model_dot_product(self.model, self.prev_mean_grad, requires_grad=False)
                    self.update_a_i(dot_prod)
                optimizer.step()
                batch_loss.append(loss.item())
                if uses_moving_average:
                    self.model.update_moving_average()
            current_epoch_loss = sum(batch_loss) / len(batch_loss)
            self.train_loss.append(float(current_epoch_loss))

        # Sample-weighted mean loss of the trained model over the local data.
        self.loss_minus = 0.0
        test_num = 0
        optimizer.zero_grad()
        for (batched_x1, batched_x2), _ in self.train_loader:
            # Use the same device as the model (moved to `device` above).
            x1, x2 = batched_x1.to(device), batched_x2.to(device)
            test_num += x1.size(0)
            loss = self._pgfed_batch_loss(conf, loss_fn, x1, x2, device)
            self.loss_minus += loss.item() * x1.size(0)
        self.loss_minus /= test_num

        # Explicit None check: module truthiness is unreliable because
        # container modules define __len__.
        if self.latest_grad is None:
            self.latest_grad = copy.deepcopy(self.model)
        # NOTE(review): no backward pass runs in the loop above and
        # latest_grad is never overwritten with gradients, so it holds a copy
        # of the model weights from the first train() call. Confirm whether
        # capturing p.grad / len(self.train_loader) should be re-enabled.
        self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
        self.train_time = time.time() - start_time
        # NOTE(review): this client does not cache the trained model in
        # _local_model, record conf.round_id, or refresh encoder_distance /
        # DAPU_predictor / weight_scaler after training.

    def update_a_i(self, dot_prod):
        """Take one clamped descent step on every tracked ``a_i`` entry.

        For each client id in ``prev_loss_minuses`` the entry is reduced by
        ``lambdaa * (stored value + dot_prod)`` and then floored at ``0.0``.

        Args:
            dot_prod: Value added to each stored entry before scaling.
        """
        for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
            self.a_i[clt_j] -= self.lambdaa * (mu_loss_minus + dot_prod)
            self.a_i[clt_j] = max(self.a_i[clt_j], 0.0)

    def set_prev_mean_grad(self, mean_grad):
        """Store ``mean_grad``: deep-copied the first time, overwritten after."""
        if self.prev_mean_grad is None:
            print("Initing prev_mean_grad")
            self.prev_mean_grad = copy.deepcopy(mean_grad)
        else:
            print("Setting prev_mean_grad")
            self.set_model(self.prev_mean_grad, mean_grad)

    def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
        """Store ``convex_comb_grad``, blending with the old value if present.

        Args:
            convex_comb_grad: Model-shaped container to store.
            momentum: Weight kept from the previously stored value; ignored
                on the first call, which deep-copies the argument.
        """
        if self.prev_convex_comb_grad is None:
            print("Initing prev_convex_comb_grad")
            self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
        else:
            print("Setting prev_convex_comb_grad")
            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)

    def set_model(self, old_m, new_m, momentum=0.0):
        """Write ``(1 - momentum) * new + momentum * old`` into ``old_m``.

        Args:
            old_m: Model whose parameter data is replaced.
            new_m: Model providing the incoming parameter data.
            momentum: Fraction of the old value to keep; ``0.0`` copies
                ``new_m`` outright.
        """
        for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
            # The arithmetic allocates a fresh tensor, so p_old never aliases
            # p_new's storage and no explicit clone is needed.
            p_old.data = (1 - momentum) * p_new.data + momentum * p_old.data
|