#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
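
"""Stochastic quasi-Newton (SQN) gradient procedures for hetero federated
linear models.

Each role (Guest, Host, Arbiter) wraps a plain hetero gradient computer.
Every ``update_interval_L`` mini-batches, the roles run an extra round that
exchanges Hessian-vector information, from which the Arbiter maintains an
L-BFGS-style approximate Hessian (see ``_compute_hess_matrix``).
"""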

import copy

import numpy as np

from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.optim.gradient import sqn_sync
from federatedml.param.sqn_param import StochasticQuasiNewtonParam
from federatedml.util import LOGGER
from federatedml.util import consts


class HeteroStochasticQuansiNewton(hetero_linear_model_gradient.HeteroGradientBase):
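    """Base class shared by the Guest, Host and Arbiter SQN procedures.

    ``update_interval_L`` is the number of mini-batches between two Hessian
    updates; ``memory_M`` is the number of (s, y) pairs kept for the
    L-BFGS-style recursion on the Arbiter.
    """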

    def __init__(self, sqn_param: StochasticQuasiNewtonParam):
        self.gradient_computer = None
        self.transfer_variable = None
        self.sqn_sync = None
        self.n_iter = 0
        self.count_t = -1
        self.__total_batch_nums = 0
        self.batch_index = 0
        self.last_w_tilde: LinearModelWeights = None
        self.this_w_tilde: LinearModelWeights = None
        self.update_interval_L = sqn_param.update_interval_L
        self.memory_M = sqn_param.memory_M
        self.sample_size = sqn_param.sample_size
        self.random_seed = sqn_param.random_seed
        self.raise_weight_overflow_error = True

    def unset_raise_weight_overflow_error(self):
        self.raise_weight_overflow_error = False

    @property
    def iter_k(self):
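        """1-based count of mini-batches processed across all epochs."""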
        return self.n_iter * self.__total_batch_nums + self.batch_index + 1

    def set_total_batch_nums(self, total_batch_nums):
        self.__total_batch_nums = total_batch_nums

    def register_gradient_computer(self, gradient_computer):
        self.gradient_computer = copy.deepcopy(gradient_computer)

    def register_transfer_variable(self, transfer_variable):
        self.transfer_variable = transfer_variable
        self.sqn_sync.register_transfer_variable(self.transfer_variable)

    def _renew_w_tilde(self):
        self.last_w_tilde = self.this_w_tilde
        self.this_w_tilde = LinearModelWeights(np.zeros_like(self.last_w_tilde.unboxed),
                                               self.last_w_tilde.fit_intercept,
                                               raise_overflow_error=self.raise_weight_overflow_error)

    def _update_hessian(self, *args):
        raise NotImplementedError("_update_hessian must be implemented by a role-specific subclass")

    def _update_w_tilde(self, model_weights):
        if self.this_w_tilde is None:
            self.this_w_tilde = copy.deepcopy(model_weights)
        else:
            self.this_w_tilde += model_weights

    def compute_gradient_procedure(self, *args, **kwargs):
        # Positional args follow the wrapped gradient computer's signature:
        # (data_instances, cipher, model_weights, optimizer, n_iter_, batch_index, ...).
        data_instances = args[0]
        cipher = args[1]
        model_weights = args[2]
        optimizer = args[3]
        self.n_iter = args[4]
        self.batch_index = args[5]

        gradient_results = self.gradient_computer.compute_gradient_procedure(*args)
        self._update_w_tilde(model_weights)
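
        # Every update_interval_L mini-batches, average the accumulated
        # weights and, from the second checkpoint on, run one Hessian-update
        # round before resetting the accumulator.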
        if self.iter_k % self.update_interval_L == 0:
            self.count_t += 1
            self.this_w_tilde /= self.update_interval_L
            if self.count_t > 0:
                LOGGER.info("iter_k: {}, count_t: {}, start to update hessian".format(self.iter_k, self.count_t))
                self._update_hessian(data_instances, optimizer, cipher)
            self._renew_w_tilde()
        return gradient_results

    def compute_loss(self, *args, **kwargs):
        return self.gradient_computer.compute_loss(*args)


class HeteroStochasticQuansiNewtonGuest(HeteroStochasticQuansiNewton):
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.sqn_sync = sqn_sync.Guest()

    def _update_hessian(self, data_instances, optimizer, cipher_operator):
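        """Guest side of one Hessian-update round.

        Samples a synchronized mini-batch, combines its own forwards with the
        hosts' to obtain the Hessian-vector product for delta_s, and sends
        the (penalized) result to the Arbiter for decryption.
        """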
        suffix = (self.n_iter, self.batch_index)
        sampled_data = self.sqn_sync.sync_sample_data(
            data_instances, self.sample_size, self.random_seed, suffix=suffix)
        delta_s = self.this_w_tilde - self.last_w_tilde
        host_forwards = self.sqn_sync.get_host_forwards(suffix=suffix)
        forward_hess, hess_vector = self.gradient_computer.compute_forward_hess(
            sampled_data, delta_s, host_forwards)
        self.sqn_sync.remote_forward_hess(forward_hess, suffix)

        # Fold in the penalty term's Hessian contribution for delta_s.
        hess_norm = optimizer.hess_vector_norm(delta_s)
        hess_vector = hess_vector + hess_norm.unboxed
        self.sqn_sync.sync_hess_vector(hess_vector, suffix)


class HeteroStochasticQuansiNewtonHost(HeteroStochasticQuansiNewton):
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.sqn_sync = sqn_sync.Host()

    def _update_hessian(self, data_instances, optimizer, cipher_operator):
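        """Host side of one Hessian-update round.

        Mirrors the Guest: computes encrypted forwards on the sampled batch,
        exchanges them for the aggregated forward_hess, then sends its own
        (penalized) Hessian-vector product to the Arbiter.
        """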
        suffix = (self.n_iter, self.batch_index)
        sampled_data = self.sqn_sync.sync_sample_data(data_instances, suffix=suffix)
        delta_s = self.this_w_tilde - self.last_w_tilde
        host_forwards = self.gradient_computer.compute_sqn_forwards(sampled_data, delta_s, cipher_operator)
        self.sqn_sync.remote_host_forwards(host_forwards, suffix=suffix)
        forward_hess = self.sqn_sync.get_forward_hess(suffix=suffix)
        hess_vector = self.gradient_computer.compute_forward_hess(sampled_data, delta_s, forward_hess)

        # Fold in the penalty term's Hessian contribution for delta_s.
        hess_vector += optimizer.hess_vector_norm(delta_s).unboxed
        self.sqn_sync.sync_hess_vector(hess_vector, suffix)


class HeteroStochasticQuansiNewtonArbiter(HeteroStochasticQuansiNewton):
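    """Arbiter side: tracks a surrogate of the global model, collects and
    decrypts the parties' Hessian-vector products, and maintains the
    approximate Hessian used by the SQN optimizer.
    """
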
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.opt_Hess = None
        self.opt_v = None
        self.opt_s = None
        self.sqn_sync = sqn_sync.Arbiter()
        self.model_weight: LinearModelWeights = None

    def _update_w_tilde(self, gradient: LinearModelWeights):
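        # The Arbiter never receives model weights directly; it maintains a
        # surrogate by accumulating the gradient updates it does see, and
        # folds that surrogate into this_w_tilde the way Guest and Host fold
        # in their real weights.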
        if self.model_weight is None:
            self.model_weight = copy.deepcopy(gradient)
        else:
            self.model_weight -= gradient
        if self.this_w_tilde is None:
            self.this_w_tilde = copy.deepcopy(self.model_weight)
        else:
            self.this_w_tilde += self.model_weight

    def compute_gradient_procedure(self, cipher_operator, optimizer, n_iter_, batch_index):
        self.n_iter = n_iter_
        self.batch_index = batch_index
        optimizer.set_hess_matrix(self.opt_Hess)
        delta_grad = self.gradient_computer.compute_gradient_procedure(
            cipher_operator, optimizer, n_iter_, batch_index)
        self._update_w_tilde(LinearModelWeights(delta_grad,
                                                fit_intercept=False,
                                                raise_overflow_error=self.raise_weight_overflow_error))
        if self.iter_k % self.update_interval_L == 0:
            self.count_t += 1
            self.this_w_tilde /= self.update_interval_L
            if self.count_t > 0:
                LOGGER.info("iter_k: {}, count_t: {}, start to update hessian".format(self.iter_k, self.count_t))
                self._update_hessian(cipher_operator)
            self._renew_w_tilde()
        return delta_grad

    def _update_hessian(self, cipher_operator):
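        """Collect and decrypt the parties' Hessian-vector products, refresh
        the (s, y) memory, and rebuild the approximate Hessian."""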
        suffix = (self.n_iter, self.batch_index)
        hess_vectors = self.sqn_sync.sync_hess_vector(suffix)
        hess_vectors = np.array(cipher_operator.decrypt_list(hess_vectors))
        delta_s = self.this_w_tilde - self.last_w_tilde
        self.opt_v = self._update_memory_vars(hess_vectors, self.opt_v)
        self.opt_s = self._update_memory_vars(delta_s.unboxed, self.opt_s)
        self._compute_hess_matrix()

    def _update_memory_vars(self, new_vars, memory_vars):
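        """Append new_vars as a column vector, keeping at most memory_M
        entries (oldest evicted first)."""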
        if isinstance(new_vars, list):
            new_vars = np.array(new_vars)
        if memory_vars is None:
            memory_vars = [new_vars.reshape(-1, 1)]
        elif len(memory_vars) < self.memory_M:
            memory_vars.append(new_vars.reshape(-1, 1))
        else:
            memory_vars.pop(0)
            memory_vars.append(new_vars.reshape(-1, 1))
        return memory_vars

    def _compute_hess_matrix(self):
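        """Rebuild the approximate Hessian from the stored (s, y) pairs.

        Starts from the scaled identity H0 = rho0 * I with
        rho0 = (y^T s) / (y^T y) for the newest pair, then applies the
        BFGS-style recursion for every stored pair:

            H <- (I - rho * s * y^T) H (I - rho * y * s^T) + rho * s * s^T,
            rho = 1 / (y^T s).
        """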
- # LOGGER.debug("opt_v: {}, opt_s: {}".format(self.opt_v, self.opt_s))
- rho = sum(self.opt_v[-1] * self.opt_s[-1]) / sum(self.opt_v[-1] * self.opt_v[-1])
- # LOGGER.debug("in _compute_hess_matrix, rho0 = {}".format(rho))
- n = self.opt_s[0].shape[0]
- Hess = rho * np.identity(n)
- iter_num = 0
- for y, s in zip(self.opt_v, self.opt_s):
- rho = 1.0 / (y.T.dot(s))
- Hess = (np.identity(n) - rho * s.dot(y.T)).dot(Hess).dot(np.identity(n) - rho * y.dot(s.T)) + rho * s.dot(
- s.T)
- iter_num += 1
- # LOGGER.info(
- # "hessian updating algorithm iter_num = {}, rho = {} \n ||s|| is {} \n ||y|| is {}".format(iter_num, rho,
- # np.linalg.norm(
- # s),
- # np.linalg.norm(
- # y)))
- self.opt_Hess = Hess


def sqn_factory(role, sqn_param):
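    """Return the role-specific SQN gradient procedure for Guest, Host or Arbiter."""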
    if role == consts.GUEST:
        return HeteroStochasticQuansiNewtonGuest(sqn_param)
    if role == consts.HOST:
        return HeteroStochasticQuansiNewtonHost(sqn_param)
    return HeteroStochasticQuansiNewtonArbiter(sqn_param)
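

# A minimal usage sketch (hypothetical driver code; the real call sites live
# in FATE's hetero linear-model trainers). `batches`, `cipher`,
# `model_weights` and `optimizer` are assumed to be provided by the trainer:
#
#     sqn = sqn_factory(consts.GUEST, sqn_param)
#     sqn.set_total_batch_nums(len(batches))
#     sqn.register_gradient_computer(base_gradient_procedure)
#     sqn.register_transfer_variable(transfer_variable)
#     for n_iter in range(max_iter):
#         for batch_index, batch_data in enumerate(batches):
#             grad = sqn.compute_gradient_procedure(
#                 batch_data, cipher, model_weights, optimizer, n_iter, batch_index)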