#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import copy

import numpy as np

from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.optim.gradient import sqn_sync
from federatedml.param.sqn_param import StochasticQuasiNewtonParam
from federatedml.util import LOGGER
from federatedml.util import consts


class HeteroStochasticQuansiNewton(hetero_linear_model_gradient.HeteroGradientBase):
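    """Base procedure for hetero stochastic quasi-Newton (SQN) gradient descent.

    A plain hetero gradient computer is wrapped and run on every mini-batch; in
    addition, the model weights are accumulated so that every ``update_interval_L``
    iterations an averaged weight difference can be exchanged between the parties
    and turned into a limited-memory BFGS-style Hessian approximation on the
    arbiter (see the role-specific subclasses below).
    """
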
    def __init__(self, sqn_param: StochasticQuasiNewtonParam):
        self.gradient_computer = None
        self.transfer_variable = None
        self.sqn_sync = None
        self.n_iter = 0
        self.count_t = -1
        self.__total_batch_nums = 0
        self.batch_index = 0
        self.last_w_tilde: LinearModelWeights = None
        self.this_w_tilde: LinearModelWeights = None
        # self.sqn_param = sqn_param
        self.update_interval_L = sqn_param.update_interval_L
        self.memory_M = sqn_param.memory_M
        self.sample_size = sqn_param.sample_size
        self.random_seed = sqn_param.random_seed
        self.raise_weight_overflow_error = True

    def unset_raise_weight_overflow_error(self):
        self.raise_weight_overflow_error = False

    @property
    def iter_k(self):
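        """1-based global mini-batch counter: n_iter * total_batch_nums + batch_index + 1."""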
        return self.n_iter * self.__total_batch_nums + self.batch_index + 1

    def set_total_batch_nums(self, total_batch_nums):
        self.__total_batch_nums = total_batch_nums

    def register_gradient_computer(self, gradient_computer):
        self.gradient_computer = copy.deepcopy(gradient_computer)

    def register_transfer_variable(self, transfer_variable):
        self.transfer_variable = transfer_variable
        self.sqn_sync.register_transfer_variable(self.transfer_variable)

    def _renew_w_tilde(self):
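        """Shift this_w_tilde into last_w_tilde and reset this_w_tilde to zeros of the same shape."""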
        self.last_w_tilde = self.this_w_tilde
        self.this_w_tilde = LinearModelWeights(np.zeros_like(self.last_w_tilde.unboxed),
                                               self.last_w_tilde.fit_intercept,
                                               raise_overflow_error=self.raise_weight_overflow_error)

    def _update_hessian(self, *args):
        raise NotImplementedError("_update_hessian must be implemented by a role-specific subclass")

    def _update_w_tilde(self, model_weights):
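        """Accumulate the current model weights; the sum is averaged over update_interval_L later."""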
        if self.this_w_tilde is None:
            self.this_w_tilde = copy.deepcopy(model_weights)
        else:
            self.this_w_tilde += model_weights

    def compute_gradient_procedure(self, *args, **kwargs):
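        """Run the wrapped gradient computer and maintain the SQN bookkeeping.

        Expected positional args (guest/host side): data_instances, cipher,
        model_weights, optimizer, n_iter, batch_index. Every update_interval_L
        iterations the accumulated weights are averaged and, from the second
        interval onward, a curvature exchange (_update_hessian) is triggered.
        """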
        data_instances = args[0]
        cipher = args[1]
        model_weights = args[2]
        optimizer = args[3]
        self.n_iter = args[4]
        self.batch_index = args[5]

        gradient_results = self.gradient_computer.compute_gradient_procedure(*args, **kwargs)
        self._update_w_tilde(model_weights)

        if self.iter_k % self.update_interval_L == 0:
            self.count_t += 1
            # LOGGER.debug("Before division, this_w_tilde: {}".format(self.this_w_tilde.unboxed))
            self.this_w_tilde /= self.update_interval_L
            # LOGGER.debug("After division, this_w_tilde: {}".format(self.this_w_tilde.unboxed))

            if self.count_t > 0:
                LOGGER.info("iter_k: {}, count_t: {}, start to update hessian".format(self.iter_k, self.count_t))
                self._update_hessian(data_instances, optimizer, cipher)
            self._renew_w_tilde()
            # LOGGER.debug("After replace, last_w_tilde: {}, this_w_tilde: {}".format(self.last_w_tilde.unboxed,
            #                                                                         self.this_w_tilde.unboxed))

        return gradient_results

    def compute_loss(self, *args, **kwargs):
        loss = self.gradient_computer.compute_loss(*args, **kwargs)
        return loss


class HeteroStochasticQuansiNewtonGuest(HeteroStochasticQuansiNewton):
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.sqn_sync = sqn_sync.Guest()

    def _update_hessian(self, data_instances, optimizer, cipher_operator):
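        """Guest side of one curvature exchange: sample a mini-batch in sync with the host,
        form delta_s = this_w_tilde - last_w_tilde, merge the host forwards into the
        Hessian-vector product, send forward_hess back to the host, add the optimizer's
        hess_vector_norm(delta_s) term and send the resulting vector to the arbiter.
        """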
        suffix = (self.n_iter, self.batch_index)

        sampled_data = self.sqn_sync.sync_sample_data(
            data_instances, self.sample_size, self.random_seed, suffix=suffix)
        delta_s = self.this_w_tilde - self.last_w_tilde
        host_forwards = self.sqn_sync.get_host_forwards(suffix=suffix)
        forward_hess, hess_vector = self.gradient_computer.compute_forward_hess(sampled_data, delta_s, host_forwards)
        self.sqn_sync.remote_forward_hess(forward_hess, suffix)
        hess_norm = optimizer.hess_vector_norm(delta_s)
        # LOGGER.debug("In _update_hessian, hess_norm: {}".format(hess_norm.unboxed))
        hess_vector = hess_vector + hess_norm.unboxed
        self.sqn_sync.sync_hess_vector(hess_vector, suffix)


class HeteroStochasticQuansiNewtonHost(HeteroStochasticQuansiNewton):
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.sqn_sync = sqn_sync.Host()

    def _update_hessian(self, data_instances, optimizer, cipher_operator):
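        """Host side of one curvature exchange: take the mini-batch sample synchronized from
        the guest, compute and send its forwards for delta_s, receive forward_hess from the
        guest, finish the local Hessian-vector product, add the optimizer's
        hess_vector_norm(delta_s) term and send the result to the arbiter.
        """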
        suffix = (self.n_iter, self.batch_index)
        sampled_data = self.sqn_sync.sync_sample_data(data_instances, suffix=suffix)
        delta_s = self.this_w_tilde - self.last_w_tilde
        # LOGGER.debug("In _update_hessian, delta_s: {}".format(delta_s.unboxed))
        host_forwards = self.gradient_computer.compute_sqn_forwards(sampled_data, delta_s, cipher_operator)
        # host_forwards = cipher_operator.encrypt_list(host_forwards)
        self.sqn_sync.remote_host_forwards(host_forwards, suffix=suffix)
        forward_hess = self.sqn_sync.get_forward_hess(suffix=suffix)
        hess_vector = self.gradient_computer.compute_forward_hess(sampled_data, delta_s, forward_hess)
        hess_vector += optimizer.hess_vector_norm(delta_s).unboxed
        self.sqn_sync.sync_hess_vector(hess_vector, suffix)


class HeteroStochasticQuansiNewtonArbiter(HeteroStochasticQuansiNewton):
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.opt_Hess = None
        self.opt_v = None
        self.opt_s = None
        self.sqn_sync = sqn_sync.Arbiter()
        self.model_weight: LinearModelWeights = None

    def _update_w_tilde(self, gradient: LinearModelWeights):
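        """The arbiter never sees the model weights, so it keeps a surrogate: it is
        initialized with the first aggregated gradient and decremented by each later one,
        and this surrogate is accumulated into this_w_tilde as on guest/host.
        """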
        if self.model_weight is None:
            self.model_weight = copy.deepcopy(gradient)
        else:
            self.model_weight -= gradient

        if self.this_w_tilde is None:
            self.this_w_tilde = copy.deepcopy(self.model_weight)
        else:
            self.this_w_tilde += self.model_weight

    def compute_gradient_procedure(self, cipher_operator, optimizer, n_iter_, batch_index):
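        """Arbiter side of one mini-batch step: pass the current Hessian approximation to
        the optimizer, run the wrapped gradient aggregation, update the weight surrogate
        and, every update_interval_L iterations, refresh the Hessian approximation from
        the parties' curvature vectors.
        """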
        self.batch_index = batch_index
        self.n_iter = n_iter_
        # LOGGER.debug("In compute_gradient_procedure, n_iter: {}, batch_index: {}, iter_k: {}".format(
        #     self.n_iter, self.batch_index, self.iter_k
        # ))

        optimizer.set_hess_matrix(self.opt_Hess)
        delta_grad = self.gradient_computer.compute_gradient_procedure(
            cipher_operator, optimizer, n_iter_, batch_index)
        self._update_w_tilde(LinearModelWeights(delta_grad,
                                                fit_intercept=False,
                                                raise_overflow_error=self.raise_weight_overflow_error))
        if self.iter_k % self.update_interval_L == 0:
            self.count_t += 1
            # LOGGER.debug("Before division, this_w_tilde: {}".format(self.this_w_tilde.unboxed))
            self.this_w_tilde /= self.update_interval_L
            # LOGGER.debug("After division, this_w_tilde: {}".format(self.this_w_tilde.unboxed))

            if self.count_t > 0:
                LOGGER.info("iter_k: {}, count_t: {}, start to update hessian".format(self.iter_k, self.count_t))
                self._update_hessian(cipher_operator)
            self._renew_w_tilde()
        return delta_grad

    def _update_hessian(self, cipher_operator):
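        """Collect and decrypt the parties' Hessian-vector products, pair them with
        delta_s = this_w_tilde - last_w_tilde, push both into the bounded memory and
        rebuild the Hessian approximation.
        """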
        suffix = (self.n_iter, self.batch_index)
        hess_vectors = self.sqn_sync.sync_hess_vector(suffix)
        hess_vectors = np.array(cipher_operator.decrypt_list(hess_vectors))
        delta_s = self.this_w_tilde - self.last_w_tilde
        # LOGGER.debug("In update hessian, hess_vectors: {}, delta_s: {}".format(
        #     hess_vectors, delta_s.unboxed
        # ))
        self.opt_v = self._update_memory_vars(hess_vectors, self.opt_v)
        self.opt_s = self._update_memory_vars(delta_s.unboxed, self.opt_s)
        self._compute_hess_matrix()

    def _update_memory_vars(self, new_vars, memory_vars):
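        """Append new_vars (as a column vector) to memory_vars, keeping at most
        memory_M entries by discarding the oldest one (FIFO window).
        """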
        if isinstance(new_vars, list):
            new_vars = np.array(new_vars)
        if memory_vars is None:
            memory_vars = [new_vars.reshape(-1, 1)]
        elif len(memory_vars) < self.memory_M:
            memory_vars.append(new_vars.reshape(-1, 1))
        else:
            memory_vars.pop(0)
            memory_vars.append(new_vars.reshape(-1, 1))
        return memory_vars

    def _compute_hess_matrix(self):
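        """Rebuild opt_Hess from the stored pairs with the BFGS recursion

            H_0 = (y^T s / y^T y) * I
            H  <- (I - rho * s * y^T) H (I - rho * y * s^T) + rho * s * s^T,  rho = 1 / (y^T s)

        where s are the averaged-weight differences (opt_s) and y the decrypted
        Hessian-vector products (opt_v); this matches the standard limited-memory
        update of the inverse-Hessian approximation.
        """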
        # LOGGER.debug("opt_v: {}, opt_s: {}".format(self.opt_v, self.opt_s))
        rho = sum(self.opt_v[-1] * self.opt_s[-1]) / sum(self.opt_v[-1] * self.opt_v[-1])
        # LOGGER.debug("in _compute_hess_matrix, rho0 = {}".format(rho))
        n = self.opt_s[0].shape[0]
        Hess = rho * np.identity(n)
        iter_num = 0
        for y, s in zip(self.opt_v, self.opt_s):
            rho = 1.0 / (y.T.dot(s))
            Hess = (np.identity(n) - rho * s.dot(y.T)).dot(Hess).dot(
                np.identity(n) - rho * y.dot(s.T)) + rho * s.dot(s.T)
            iter_num += 1
            # LOGGER.info("hessian updating algorithm iter_num = {}, rho = {}, ||s|| is {}, ||y|| is {}".format(
            #     iter_num, rho, np.linalg.norm(s), np.linalg.norm(y)))

        self.opt_Hess = Hess


def sqn_factory(role, sqn_param):
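    """Return the SQN gradient procedure matching the party role (guest, host, else arbiter)."""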
    if role == consts.GUEST:
        return HeteroStochasticQuansiNewtonGuest(sqn_param)

    if role == consts.HOST:
        return HeteroStochasticQuansiNewtonHost(sqn_param)

    return HeteroStochasticQuansiNewtonArbiter(sqn_param)