import copy
import logging
import os
import time

import torch
import torch.distributed as dist
from torchvision import datasets

import model
import utils
from communication import TARGET
from easyfl.datasets.data import CIFAR100
from easyfl.distributed import reduce_models
from easyfl.distributed.distributed import CPU
from easyfl.protocol import codec
from easyfl.server import strategies
from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
from easyfl.tracking import metric
from knn_monitor import knn_monitor
from server import FedSSLServer

logger = logging.getLogger(__name__)


class FedSSLWithPgFedServer(FedSSLServer):
    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        super(FedSSLWithPgFedServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
        self.train_loader = None
        self.test_loader = None
        self.mu = 0
        self.momentum = 0.0
        self.alpha_mat = None
        self.uploaded_grads = {}
        self.loss_minuses = {}
        self.mean_grad = None
        self.convex_comb_grad = None

    def set_clients(self, clients):
        """Register clients and assign each a sequential integer id."""
        self._clients = clients
        for i, _ in enumerate(self._clients):
            self._clients[i].id = i

    def train(self):
        """Training process of federated learning."""
        self.print_("--- start training ---")
        print(f"\nJoin clients / total clients: {self.conf.server.clients_per_round} / {len(self._clients)}")
        self.selection(self._clients, self.conf.server.clients_per_round)
        self.grouping_for_distributed()
        self.compression()

        begin_train_time = time.time()
        self.send_param()
        self.distribution_to_train()
        self.aggregation()
        train_time = time.time() - begin_train_time

        self.print_("Server train time: {}".format(train_time))
        self.track(metric.TRAIN_TIME, train_time)

    def send_param(self):
        """Send PgFed state (alpha row, previous mean gradient, convex-combination
        gradient, and loss minuses) to the selected clients before local training."""
        if self.alpha_mat is None:
            # Initialize the client-to-client weight matrix uniformly.
            # Note: `if not self.alpha_mat` would raise on a multi-element
            # tensor in later rounds, so test against None explicitly.
            self.alpha_mat = (torch.ones((len(self._clients), len(self._clients)))
                              / self.conf.server.clients_per_round).to(self.conf.device)
        for client in self.grouped_clients:
            client.a_i = self.alpha_mat[client.id]
        if len(self.uploaded_grads) == 0:
            # First round: no gradients have been uploaded yet.
            return
        # Reused as an in-place buffer; model_weighted_sum zeroes it for each client.
        self.convex_comb_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
        for client in self.grouped_clients:
            client.set_prev_mean_grad(self.mean_grad)
            mu_a_i = self.alpha_mat[client.id] * self.mu
            grads, weights = [], []
            for clt_idx, grad in self.uploaded_grads.items():
                weights.append(mu_a_i[clt_idx])
                grads.append(grad)
            self.model_weighted_sum(self.convex_comb_grad, grads, weights)
            client.set_prev_convex_comb_grad(self.convex_comb_grad, momentum=self.momentum)
            client.prev_loss_minuses = copy.deepcopy(self.loss_minuses)

    def distribution_to_train_locally(self):
        """Conduct training sequentially for selected clients in the group."""
        uploaded_models = {}
        uploaded_weights = {}
        uploaded_metrics = []
        for client in self.grouped_clients:
            # Update client config before training.
            self.conf.client.task_id = self.conf.task_id
            self.conf.client.round_id = self._current_round
            uploaded_request = client.run_train(self._compressed_model, self.conf.client)
            uploaded_content = uploaded_request.content
            # Renamed from `model` to avoid shadowing the top-level import.
            client_model = self.decompression(codec.unmarshal(uploaded_content.data))
            uploaded_models[client.cid] = client_model
            uploaded_weights[client.cid] = uploaded_content.data_size
            uploaded_metrics.append(metric.ClientMetric.from_proto(uploaded_content.metric))
        self.receive_param()
        self.set_client_uploads_train(uploaded_models, uploaded_weights, uploaded_metrics)

    def receive_param(self):
        """Collect PgFed state uploaded by the selected clients after local training."""
        self.uploaded_ids = []
        self.uploaded_grads = {}
        self.loss_minuses = {}
        for client in self.selected_clients:
            self.uploaded_ids.append(client.id)
            self.alpha_mat[client.id] = client.a_i
            self.uploaded_grads[client.id] = client.latest_grad
            self.loss_minuses[client.id] = client.loss_minus * self.mu

    def get_mean_grad(self):
        """Compute the mu-scaled mean of the uploaded gradients."""
        w = self.mu / self.conf.server.clients_per_round
        weights = [w for _ in range(self.conf.server.clients_per_round)]
        self.mean_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
        self.model_weighted_sum(self.mean_grad, list(self.uploaded_grads.values()), weights)

    def model_weighted_sum(self, model, models, weights):
        """Overwrite `model` in place with the weighted sum of the parameters of `models`."""
        for p_m in model.parameters():
            p_m.data.zero_()
        for w, m_i in zip(weights, models):
            for p_m, p_i in zip(model.parameters(), m_i.parameters()):
                p_m.data += p_i.data.clone() * w
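

# --- Illustrative sanity check (not part of the original module) ---
# A minimal sketch verifying model_weighted_sum on two toy linear models;
# it assumes only torch. Since the method never touches `self`, it can be
# called unbound with None in place of a server instance.
if __name__ == "__main__":
    import torch.nn as nn

    target = nn.Linear(2, 2, bias=False)
    m1 = nn.Linear(2, 2, bias=False)
    m2 = nn.Linear(2, 2, bias=False)
    with torch.no_grad():
        m1.weight.fill_(1.0)  # every parameter set to 1.0
        m2.weight.fill_(3.0)  # every parameter set to 3.0

    # Expected result: 0.5 * 1.0 + 0.5 * 3.0 == 2.0 for every entry.
    FedSSLWithPgFedServer.model_weighted_sum(None, target, [m1, m2], [0.5, 0.5])
    assert torch.allclose(target.weight, torch.full((2, 2), 2.0))
    print("model_weighted_sum OK")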