123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- import copy
- import logging
- import os
- import torch
- import torch.distributed as dist
- from torchvision import datasets
- import time
- import model
- import utils
- from communication import TARGET
- from easyfl.datasets.data import CIFAR100
- from easyfl.distributed import reduce_models
- from easyfl.distributed.distributed import CPU
- from easyfl.server import strategies
- from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
- from easyfl.tracking import metric
- from easyfl.protocol import codec
- from knn_monitor import knn_monitor
- from server import FedSSLServer
logger = logging.getLogger(__name__)


class FedSSLWithPgFedServer(FedSSLServer):
    """FedSSL server extended with PgFed-style client-to-client weighting.

    On top of ``FedSSLServer`` this server keeps:

    * ``alpha_mat`` -- a ``(num_clients, num_clients)`` tensor whose row ``i``
      holds the weights client ``i`` assigns to every client (indexed by the
      integer ``id`` given in :meth:`set_clients`);
    * ``uploaded_grads`` / ``loss_minuses`` -- per-client ``latest_grad`` and
      ``loss_minus`` values collected after local training.

    Before each round it sends every selected client its alpha row, the
    scaled mean of the previously uploaded gradients, and its own
    ``mu``-scaled alpha-weighted combination of those gradients.
    """

    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        super().__init__(conf, test_data, val_data, is_remote, local_port)
        self.train_loader = None
        self.test_loader = None
        # Multiplier applied to alpha rows (send_param) and to the mean-gradient
        # weights (get_mean_grad).
        # NOTE(review): with mu == 0 every combination sent to clients is all
        # zeros; presumably overridden from config elsewhere -- confirm.
        self.mu = 0
        # Forwarded to client.set_prev_convex_comb_grad.
        self.momentum = 0.0
        # Created lazily in send_param once the client count is known.
        self.alpha_mat = None
        # client id -> gradient model uploaded in the last round.
        self.uploaded_grads = {}
        # client id -> loss_minus value uploaded in the last round.
        self.loss_minuses = {}
        self.mean_grad = None
        self.convex_comb_grad = None

    def set_clients(self, clients):
        """Register ``clients`` and give each a dense integer ``id``.

        The id equals the client's position in ``clients`` and is used as its
        row/column index into ``alpha_mat``.
        """
        self._clients = clients
        for index, client in enumerate(self._clients):
            client.id = index

    def train(self):
        """Training process of federated learning."""
        self.print_("--- start training ---")
        print(f"\nJoin clients / total clients: {self.conf.server.clients_per_round} / {len(self._clients)}")
        self.selection(self._clients, self.conf.server.clients_per_round)
        self.grouping_for_distributed()
        self.compression()

        begin_train_time = time.time()
        # Hand out alpha rows and last round's gradient combinations first.
        self.send_param()
        self.distribution_to_train()
        self.aggregation()
        # Refresh the mean gradient from what clients uploaded this round.
        self.get_mean_grad()
        train_time = time.time() - begin_train_time
        self.print_("Server train time: {}".format(train_time))
        self.track(metric.TRAIN_TIME, train_time)

    def send_param(self):
        """Send PgFed state to the clients grouped for this round.

        Every grouped client receives its row of ``alpha_mat`` as ``a_i``.
        Once gradients have been uploaded (i.e. from the second round on),
        each client also receives ``mean_grad``, its own ``mu``-scaled
        alpha-weighted sum of ``uploaded_grads``, and a copy of
        ``loss_minuses``.
        """
        if self.alpha_mat is None:
            num_clients = len(self._clients)
            self.alpha_mat = (
                torch.ones((num_clients, num_clients)) / self.conf.server.clients_per_round
            ).to(self.conf.device)
        for client in self.grouped_clients:
            # Integer indexing returns a view, so in-place edits of a_i by the
            # client are reflected in alpha_mat.
            client.a_i = self.alpha_mat[client.id]
        if not self.uploaded_grads:
            # Nothing uploaded yet (first round): no combinations to send.
            return

        grads = list(self.uploaded_grads.values())
        # Scratch model reused for every client; model_weighted_sum overwrites
        # its parameters on each iteration.
        # NOTE(review): the same object is passed to every client, so this
        # relies on set_prev_convex_comb_grad copying values rather than
        # keeping a reference -- confirm in the client implementation.
        self.convex_comb_grad = copy.deepcopy(grads[0])
        for client in self.grouped_clients:
            client.set_prev_mean_grad(self.mean_grad)
            mu_a_i = self.alpha_mat[client.id] * self.mu
            weights = [mu_a_i[clt_idx] for clt_idx in self.uploaded_grads]
            self.model_weighted_sum(self.convex_comb_grad, grads, weights)
            client.set_prev_convex_comb_grad(self.convex_comb_grad, momentum=self.momentum)
            client.prev_loss_minuses = copy.deepcopy(self.loss_minuses)

    def distribution_to_train_locally(self):
        """Conduct training sequentially for selected clients in the group.

        After every client has trained, PgFed state is collected once via
        :meth:`receive_param`, then the uploads are stored with
        ``set_client_uploads_train``.
        """
        uploaded_models = {}
        uploaded_weights = {}
        uploaded_metrics = []
        for client in self.grouped_clients:
            # Update client config before training
            self.conf.client.task_id = self.conf.task_id
            self.conf.client.round_id = self._current_round
            uploaded_request = client.run_train(self._compressed_model, self.conf.client)
            uploaded_content = uploaded_request.content
            # Named so it does not shadow the module-level ``model`` import.
            client_model = self.decompression(codec.unmarshal(uploaded_content.data))
            uploaded_models[client.cid] = client_model
            uploaded_weights[client.cid] = uploaded_content.data_size
            uploaded_metrics.append(metric.ClientMetric.from_proto(uploaded_content.metric))
        # Collect PgFed state once, after all clients have finished training;
        # doing it per client would re-read not-yet-trained clients.
        self.receive_param()
        self.set_client_uploads_train(uploaded_models, uploaded_weights, uploaded_metrics)

    def receive_param(self):
        """Collect PgFed state from the selected clients after local training.

        Writes each client's ``a_i`` back into ``alpha_mat`` and replaces
        ``uploaded_ids``, ``uploaded_grads`` and ``loss_minuses`` with this
        round's values, keyed by client id.
        """
        self.uploaded_ids = []
        self.uploaded_grads = {}
        self.loss_minuses = {}
        for client in self.selected_clients:
            self.uploaded_ids.append(client.id)
            self.alpha_mat[client.id] = client.a_i
            self.uploaded_grads[client.id] = client.latest_grad
            logger.debug("client.loss_minus: %s", client.loss_minus)
            self.loss_minuses[client.id] = client.loss_minus

    def get_mean_grad(self):
        """Set ``mean_grad`` to ``mu`` times the mean of ``uploaded_grads``.

        Leaves ``mean_grad`` unchanged when no gradients were uploaded.
        """
        grads = list(self.uploaded_grads.values())
        if not grads:
            return
        # Divide by the number of gradients actually uploaded (not
        # clients_per_round) so the result stays a true mean even when fewer
        # clients report in a round.
        weight = self.mu / len(grads)
        self.mean_grad = copy.deepcopy(grads[0])
        self.model_weighted_sum(self.mean_grad, grads, [weight] * len(grads))

    def model_weighted_sum(self, model, models, weights):
        """Overwrite ``model``'s parameters with ``sum(w_i * models[i])``.

        ``model`` is zeroed first, so it must not itself appear in ``models``.
        ``weights`` and ``models`` are paired with ``zip``; surplus entries in
        the longer sequence are ignored.
        """
        for p_m in model.parameters():
            p_m.data.zero_()
        for w, m_i in zip(weights, models):
            for p_m, p_i in zip(model.parameters(), m_i.parameters()):
                # ``p_i.data * w`` already allocates a new tensor; no clone needed.
                p_m.data += p_i.data * w
|