123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- import copy
- import logging
- import os
- import torch
- import torch.distributed as dist
- from torchvision import datasets
- import model
- import utils
- from communication import TARGET
- from easyfl.datasets.data import CIFAR100
- from easyfl.distributed import reduce_models
- from easyfl.distributed.distributed import CPU
- from easyfl.server import strategies
- from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
- from easyfl.tracking import metric
- from knn_monitor import knn_monitor
- logger = logging.getLogger(__name__)
- class FedSSLServer(BaseServer):
- def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
- super(FedSSLServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
- self.train_loader = None
- self.test_loader = None
- def aggregation(self):
- if self.conf.client.auto_scaler == 'y' and self.conf.server.random_selection:
- self._retain_weight_scaler()
- uploaded_content = self.get_client_uploads()
- models = list(uploaded_content[MODEL].values())
- weights = list(uploaded_content[DATA_SIZE].values())
- # Aggregate networks gradually with different components.
- if self.conf.model in [model.Symmetric, model.SymmetricNoSG, model.SimSiam, model.SimSiamNoSG, model.BYOL,
- model.BYOLNoSG, model.BYOLNoPredictor, model.SimCLR]:
- online_encoders = [m.online_encoder for m in models]
- online_encoder = self._federated_averaging(online_encoders, weights)
- self._model.online_encoder.load_state_dict(online_encoder.state_dict())
- if self.conf.model in [model.SimSiam, model.SimSiamNoSG, model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
- predictors = [m.online_predictor for m in models]
- predictor = self._federated_averaging(predictors, weights)
- self._model.online_predictor.load_state_dict(predictor.state_dict())
- if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
- target_encoders = [m.target_encoder for m in models]
- target_encoder = self._federated_averaging(target_encoders, weights)
- self._model.target_encoder = copy.deepcopy(target_encoder)
- if self.conf.model in [model.MoCo, model.MoCoV2]:
- encoder_qs = [m.encoder_q for m in models]
- encoder_q = self._federated_averaging(encoder_qs, weights)
- self._model.encoder_q.load_state_dict(encoder_q.state_dict())
- encoder_ks = [m.encoder_k for m in models]
- encoder_k = self._federated_averaging(encoder_ks, weights)
- self._model.encoder_k.load_state_dict(encoder_k.state_dict())
- def _retain_weight_scaler(self):
- self.client_id_to_index = {c.cid: i for i, c in enumerate(self._clients)}
- client_index = self.client_id_to_index[self.grouped_clients[0].cid]
- weight_scaler = self.grouped_clients[0].weight_scaler if self.grouped_clients[0].weight_scaler else 0
- scaler = torch.tensor((client_index, weight_scaler)).to(self.conf.device)
- scalers = [torch.zeros_like(scaler) for _ in self.selected_clients]
- dist.barrier()
- dist.all_gather(scalers, scaler)
- logger.info(f"Synced scaler {scalers}")
- for i, client in enumerate(self._clients):
- for scaler in scalers:
- scaler = scaler.cpu().numpy()
- if self.client_id_to_index[client.cid] == int(scaler[0]) and not client.weight_scaler:
- self._clients[i].weight_scaler = scaler[1]
- def _federated_averaging(self, models, weights):
- fn_average = strategies.federated_averaging
- fn_sum = strategies.weighted_sum
- fn_reduce = reduce_models
- if self.conf.is_distributed:
- dist.barrier()
- model_, sample_sum = fn_sum(models, weights)
- fn_reduce(model_, torch.tensor(sample_sum).to(self.conf.device))
- else:
- model_ = fn_average(models, weights)
- return model_
- def test_in_server(self, device=CPU):
- testing_model = self._get_testing_model()
- testing_model.eval()
- testing_model.to(device)
- self._get_test_data()
- with torch.no_grad():
- accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader, device=device)
- test_results = {
- metric.TEST_ACCURACY: float(accuracy),
- metric.TEST_LOSS: 0,
- }
- return test_results
- def _get_test_data(self):
- transformation = self._load_transform()
- if self.train_loader is None or self.test_loader is None:
- if self.conf.data.dataset == CIFAR100:
- data_path = "./data/cifar100"
- train_dataset = datasets.CIFAR100(data_path, download=True, transform=transformation)
- test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
- else:
- data_path = "./data/cifar10"
- train_dataset = datasets.CIFAR10(data_path, download=True, transform=transformation)
- test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)
- if self.train_loader is None:
- self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, num_workers=8)
- if self.test_loader is None:
- self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512, num_workers=8)
- def _load_transform(self):
- transformation = utils.get_transformation(self.conf.model)
- return transformation().test_transform
- def _get_testing_model(self, net=False):
- if self.conf.model in [model.MoCo, model.MoCoV2]:
- testing_model = self._model.encoder_q
- elif self.conf.model in [model.SimSiam, model.SimSiamNoSG, model.Symmetric, model.SymmetricNoSG, model.SimCLR]:
- testing_model = self._model.online_encoder
- else:
- # BYOL
- if self.conf.client.aggregate_encoder == TARGET:
- self.print_("Use aggregated target encoder for testing")
- testing_model = self._model.target_encoder
- else:
- self.print_("Use aggregated online encoder for testing")
- testing_model = self._model.online_encoder
- return testing_model
- def save_model(self):
- if self._do_every(self.conf.server.save_model_every, self._current_round, self.conf.server.rounds) and self.is_primary_server():
- save_path = self.conf.server.save_model_path
- if save_path == "":
- save_path = os.path.join(os.getcwd(), "saved_models", self.conf.task_id)
- os.makedirs(save_path, exist_ok=True)
- save_path = os.path.join(save_path,
- "{}_global_model_r_{}.pth".format(self.conf.task_id, self._current_round))
- torch.save(self._get_testing_model().cpu().state_dict(), save_path)
- self.print_("Encoder model saved at {}".format(save_path))
- if self.conf.server.save_predictor:
- if self.conf.model in [model.SimSiam, model.BYOL]:
- save_path = save_path.replace("global_model", "predictor")
- torch.save(self._model.online_predictor.cpu().state_dict(), save_path)
- self.print_("Predictor model saved at {}".format(save_path))
|