import copy
import logging
import os

import torch
import torch.distributed as dist
from torchvision import datasets

import model
import utils
from communication import TARGET
from easyfl.datasets.data import CIFAR100
from easyfl.distributed import reduce_models
from easyfl.distributed.distributed import CPU
from easyfl.server import strategies
from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
from easyfl.tracking import metric
from knn_monitor import knn_monitor

logger = logging.getLogger(__name__)


class FedSSLServer(BaseServer):
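    """Federated self-supervised learning server.

    Aggregates the encoder/predictor components uploaded by clients according to
    the configured SSL method and evaluates the global encoder with a KNN monitor.
    """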
    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        super(FedSSLServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
        self.train_loader = None
        self.test_loader = None

    def aggregation(self):
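        """Aggregate client uploads into the global model.

        Different SSL methods share different components, so the online encoder,
        online predictor, target encoder, or MoCo's query/key encoders are
        averaged separately.
        """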
        if self.conf.client.auto_scaler == 'y' and self.conf.server.random_selection:
            self._retain_weight_scaler()

        uploaded_content = self.get_client_uploads()
        models = list(uploaded_content[MODEL].values())
        weights = list(uploaded_content[DATA_SIZE].values())

        # Aggregate the model component by component; which components exist depends on the SSL method.
        if self.conf.model in [model.Symmetric, model.SymmetricNoSG, model.SimSiam, model.SimSiamNoSG, model.BYOL,
                               model.BYOLNoSG, model.BYOLNoPredictor, model.SimCLR]:
            online_encoders = [m.online_encoder for m in models]
            online_encoder = self._federated_averaging(online_encoders, weights)
            self._model.online_encoder.load_state_dict(online_encoder.state_dict())

        if self.conf.model in [model.SimSiam, model.SimSiamNoSG, model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
            predictors = [m.online_predictor for m in models]
            predictor = self._federated_averaging(predictors, weights)
            self._model.online_predictor.load_state_dict(predictor.state_dict())

        if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
            target_encoders = [m.target_encoder for m in models]
            target_encoder = self._federated_averaging(target_encoders, weights)
            self._model.target_encoder = copy.deepcopy(target_encoder)

        if self.conf.model in [model.MoCo, model.MoCoV2]:
            encoder_qs = [m.encoder_q for m in models]
            encoder_q = self._federated_averaging(encoder_qs, weights)
            self._model.encoder_q.load_state_dict(encoder_q.state_dict())

            encoder_ks = [m.encoder_k for m in models]
            encoder_k = self._federated_averaging(encoder_ks, weights)
            self._model.encoder_k.load_state_dict(encoder_k.state_dict())

    def _retain_weight_scaler(self):
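        """Synchronize clients' auto-scaler weights across distributed processes.

        Each process contributes the (client index, weight scaler) pair of its first
        grouped client via all_gather; the gathered values are then written back to
        the matching clients that do not have a scaler yet.
        """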
        self.client_id_to_index = {c.cid: i for i, c in enumerate(self._clients)}

        client_index = self.client_id_to_index[self.grouped_clients[0].cid]
        weight_scaler = self.grouped_clients[0].weight_scaler if self.grouped_clients[0].weight_scaler else 0
        scaler = torch.tensor((client_index, weight_scaler)).to(self.conf.device)
        scalers = [torch.zeros_like(scaler) for _ in self.selected_clients]
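        # NOTE: assumes one distributed process per selected client so that all_gather fills `scalers`.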
        dist.barrier()
        dist.all_gather(scalers, scaler)

        logger.info(f"Synced scaler {scalers}")
        for i, client in enumerate(self._clients):
            for scaler in scalers:
                scaler = scaler.cpu().numpy()
                if self.client_id_to_index[client.cid] == int(scaler[0]) and not client.weight_scaler:
                    self._clients[i].weight_scaler = scaler[1]

    def _federated_averaging(self, models, weights):
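        """Average a list of client models weighted by their data sizes.

        In distributed mode, each process computes a weighted sum of its local
        models and the partial sums are reduced across processes; otherwise
        plain federated averaging is applied locally.
        """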
        fn_average = strategies.federated_averaging
        fn_sum = strategies.weighted_sum
        fn_reduce = reduce_models

        if self.conf.is_distributed:
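            # Weighted-sum the models held by this process, then reduce the partial sums
            # across processes to obtain the global average.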
            dist.barrier()
            model_, sample_sum = fn_sum(models, weights)
            fn_reduce(model_, torch.tensor(sample_sum).to(self.conf.device))
        else:
            model_ = fn_average(models, weights)
        return model_

    def test_in_server(self, device=CPU):
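        """Evaluate the global encoder on the server with a KNN monitor and
        report the resulting accuracy as the test metric."""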
        testing_model = self._get_testing_model()
        testing_model.eval()
        testing_model.to(device)

        self._get_test_data()

        with torch.no_grad():
            accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader)

        test_results = {
            metric.TEST_ACCURACY: float(accuracy),
            metric.TEST_LOSS: 0,
        }
        return test_results

    def _get_test_data(self):
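        """Lazily build the CIFAR-10/CIFAR-100 train and test loaders used by the KNN monitor."""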
        transformation = self._load_transform()
        if self.train_loader is None or self.test_loader is None:
            if self.conf.data.dataset == CIFAR100:
                data_path = "./data/cifar100"
                train_dataset = datasets.CIFAR100(data_path, download=True, transform=transformation)
                test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
            else:
                data_path = "./data/cifar10"
                train_dataset = datasets.CIFAR10(data_path, download=True, transform=transformation)
                test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)

            if self.train_loader is None:
                self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, num_workers=8)

            if self.test_loader is None:
                self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512, num_workers=8)

    def _load_transform(self):
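        """Return the test-time transformation matching the configured model."""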
        transformation = utils.get_transformation(self.conf.model)
        return transformation().test_transform

    def _get_testing_model(self, net=False):
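        """Select the encoder used for evaluation: the query encoder for MoCo,
        the online encoder for SimSiam/Symmetric/SimCLR, and the target or
        online encoder for BYOL depending on the aggregation setting."""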
        if self.conf.model in [model.MoCo, model.MoCoV2]:
            testing_model = self._model.encoder_q
        elif self.conf.model in [model.SimSiam, model.SimSiamNoSG, model.Symmetric, model.SymmetricNoSG, model.SimCLR]:
            testing_model = self._model.online_encoder
        else:
            # BYOL
            if self.conf.client.aggregate_encoder == TARGET:
                self.print_("Use aggregated target encoder for testing")
                testing_model = self._model.target_encoder
            else:
                self.print_("Use aggregated online encoder for testing")
                testing_model = self._model.online_encoder
        return testing_model

    def save_model(self):
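        """Save the global encoder (and optionally the predictor) on the primary
        server at the configured round interval."""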
        if self._do_every(self.conf.server.save_model_every, self._current_round, self.conf.server.rounds) and self.is_primary_server():
            save_path = self.conf.server.save_model_path
            if save_path == "":
                save_path = os.path.join(os.getcwd(), "saved_models", self.conf.task_id)
            os.makedirs(save_path, exist_ok=True)
            save_path = os.path.join(save_path,
                                     "{}_global_model_r_{}.pth".format(self.conf.task_id, self._current_round))

            torch.save(self._get_testing_model().cpu().state_dict(), save_path)
            self.print_("Encoder model saved at {}".format(save_path))

            if self.conf.server.save_predictor:
                if self.conf.model in [model.SimSiam, model.BYOL]:
                    save_path = save_path.replace("global_model", "predictor")
                    torch.save(self._model.online_predictor.cpu().state_dict(), save_path)
                    self.print_("Predictor model saved at {}".format(save_path))