hetero_sqn_gradient.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import numpy as np

from federatedml.linear_model.linear_model_weight import LinearModelWeights
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.optim.gradient import sqn_sync
from federatedml.param.sqn_param import StochasticQuasiNewtonParam
from federatedml.util import LOGGER
from federatedml.util import consts


class HeteroStochasticQuansiNewton(hetero_linear_model_gradient.HeteroGradientBase):
    """Base gradient procedure for hetero stochastic quasi-Newton (SQN) training."""

    def __init__(self, sqn_param: StochasticQuasiNewtonParam):
        self.gradient_computer = None
        self.transfer_variable = None
        self.sqn_sync = None
        self.n_iter = 0        # current epoch (outer iteration)
        self.count_t = -1      # index of the current averaging interval
        self.__total_batch_nums = 0
        self.batch_index = 0
        self.last_w_tilde: LinearModelWeights = None   # previous averaged iterate
        self.this_w_tilde: LinearModelWeights = None   # running sum, averaged every L batches
        self.update_interval_L = sqn_param.update_interval_L
        self.memory_M = sqn_param.memory_M
        self.sample_size = sqn_param.sample_size
        self.random_seed = sqn_param.random_seed
        self.raise_weight_overflow_error = True

    def unset_raise_weight_overflow_error(self):
        self.raise_weight_overflow_error = False

    @property
    def iter_k(self):
        # Global (1-based) mini-batch counter across epochs.
        return self.n_iter * self.__total_batch_nums + self.batch_index + 1
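    # For example, with set_total_batch_nums(10), n_iter = 2 and batch_index = 3,
    # iter_k = 2 * 10 + 3 + 1 = 24, i.e. the 24th mini-batch processed overall.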

    def set_total_batch_nums(self, total_batch_nums):
        self.__total_batch_nums = total_batch_nums

    def register_gradient_computer(self, gradient_computer):
        # Keep a private copy of the plain gradient computer; SQN wraps it.
        self.gradient_computer = copy.deepcopy(gradient_computer)

    def register_transfer_variable(self, transfer_variable):
        self.transfer_variable = transfer_variable
        self.sqn_sync.register_transfer_variable(self.transfer_variable)

    def _renew_w_tilde(self):
        # Shift the current averaged iterate into last_w_tilde and reset the accumulator.
        self.last_w_tilde = self.this_w_tilde
        self.this_w_tilde = LinearModelWeights(np.zeros_like(self.last_w_tilde.unboxed),
                                               self.last_w_tilde.fit_intercept,
                                               raise_overflow_error=self.raise_weight_overflow_error)

    def _update_hessian(self, *args):
        raise NotImplementedError("Should not call here")

    def _update_w_tilde(self, model_weights):
        # Accumulate this batch's weights; the sum is averaged every update_interval_L batches.
        if self.this_w_tilde is None:
            self.this_w_tilde = copy.deepcopy(model_weights)
        else:
            self.this_w_tilde += model_weights

    def compute_gradient_procedure(self, *args, **kwargs):
        data_instances = args[0]
        cipher = args[1]
        model_weights = args[2]
        optimizer = args[3]
        self.n_iter = args[4]
        self.batch_index = args[5]
        gradient_results = self.gradient_computer.compute_gradient_procedure(*args)
        self._update_w_tilde(model_weights)

        if self.iter_k % self.update_interval_L == 0:
            self.count_t += 1
            # Average the weights accumulated over the last L batches.
            self.this_w_tilde /= self.update_interval_L
            if self.count_t > 0:
                LOGGER.info("iter_k: {}, count_t: {}, start to update hessian".format(self.iter_k, self.count_t))
                self._update_hessian(data_instances, optimizer, cipher)
            self._renew_w_tilde()
        return gradient_results
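    # Illustrative schedule (not in the original source), with update_interval_L = 3:
    # batches 1-3 accumulate weights into this_w_tilde, which is then averaged to
    # w_tilde_t = (w_1 + w_2 + w_3) / 3. Consecutive averages form the curvature
    # direction delta_s = this_w_tilde - last_w_tilde used in _update_hessian.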

    def compute_loss(self, *args, **kwargs):
        loss = self.gradient_computer.compute_loss(*args)
        return loss


class HeteroStochasticQuansiNewtonGuest(HeteroStochasticQuansiNewton):
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.sqn_sync = sqn_sync.Guest()

    def _update_hessian(self, data_instances, optimizer, cipher_operator):
        suffix = (self.n_iter, self.batch_index)
        # Guest and host sample the same rows so curvature is estimated on aligned instances.
        sampled_data = self.sqn_sync.sync_sample_data(data_instances, self.sample_size,
                                                      self.random_seed, suffix=suffix)
        delta_s = self.this_w_tilde - self.last_w_tilde
        host_forwards = self.sqn_sync.get_host_forwards(suffix=suffix)
        forward_hess, hess_vector = self.gradient_computer.compute_forward_hess(sampled_data, delta_s, host_forwards)
        self.sqn_sync.remote_forward_hess(forward_hess, suffix)
        # Add the regularization term's contribution to the Hessian-vector product.
        hess_norm = optimizer.hess_vector_norm(delta_s)
        hess_vector = hess_vector + hess_norm.unboxed
        self.sqn_sync.sync_hess_vector(hess_vector, suffix)
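    # Message flow for one Hessian update, as implied by the sync calls above
    # (a reading aid, not part of the original source):
    #   1. guest -> host: ids of the shared row sample (sync_sample_data)
    #   2. host -> guest: encrypted forward products on that sample
    #   3. guest -> host: aggregated forward_hess
    #   4. guest and host -> arbiter: hess_vector shares for decryption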


class HeteroStochasticQuansiNewtonHost(HeteroStochasticQuansiNewton):
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.sqn_sync = sqn_sync.Host()

    def _update_hessian(self, data_instances, optimizer, cipher_operator):
        suffix = (self.n_iter, self.batch_index)
        sampled_data = self.sqn_sync.sync_sample_data(data_instances, suffix=suffix)
        delta_s = self.this_w_tilde - self.last_w_tilde
        host_forwards = self.gradient_computer.compute_sqn_forwards(sampled_data, delta_s, cipher_operator)
        self.sqn_sync.remote_host_forwards(host_forwards, suffix=suffix)
        forward_hess = self.sqn_sync.get_forward_hess(suffix=suffix)
        hess_vector = self.gradient_computer.compute_forward_hess(sampled_data, delta_s, forward_hess)
        # Add the regularization term's contribution, mirroring the guest side.
        hess_vector += optimizer.hess_vector_norm(delta_s).unboxed
        self.sqn_sync.sync_hess_vector(hess_vector, suffix)


class HeteroStochasticQuansiNewtonArbiter(HeteroStochasticQuansiNewton):
    def __init__(self, sqn_param):
        super().__init__(sqn_param)
        self.opt_Hess = None   # current Hessian approximation passed to the optimizer
        self.opt_v = None      # memory of recent Hessian-vector products y_t
        self.opt_s = None      # memory of recent iterate differences s_t
        self.sqn_sync = sqn_sync.Arbiter()
        self.model_weight: LinearModelWeights = None

    def _update_w_tilde(self, gradient: LinearModelWeights):
        # The arbiter never sees model weights, so it tracks a surrogate weight
        # trajectory by accumulating the aggregated gradient updates.
        if self.model_weight is None:
            self.model_weight = copy.deepcopy(gradient)
        else:
            self.model_weight -= gradient
        if self.this_w_tilde is None:
            self.this_w_tilde = copy.deepcopy(self.model_weight)
        else:
            self.this_w_tilde += self.model_weight

    def compute_gradient_procedure(self, cipher_operator, optimizer, n_iter_, batch_index):
        self.batch_index = batch_index
        self.n_iter = n_iter_
        # Hand the latest Hessian approximation to the optimizer before aggregation.
        optimizer.set_hess_matrix(self.opt_Hess)
        delta_grad = self.gradient_computer.compute_gradient_procedure(
            cipher_operator, optimizer, n_iter_, batch_index)
        self._update_w_tilde(LinearModelWeights(delta_grad,
                                                fit_intercept=False,
                                                raise_overflow_error=self.raise_weight_overflow_error))
        if self.iter_k % self.update_interval_L == 0:
            self.count_t += 1
            # Average the accumulated surrogate weights over the last L batches.
            self.this_w_tilde /= self.update_interval_L
            if self.count_t > 0:
                LOGGER.info("iter_k: {}, count_t: {}, start to update hessian".format(self.iter_k, self.count_t))
                self._update_hessian(cipher_operator)
            self._renew_w_tilde()
        return delta_grad

    def _update_hessian(self, cipher_operator):
        suffix = (self.n_iter, self.batch_index)
        hess_vectors = self.sqn_sync.sync_hess_vector(suffix)
        hess_vectors = np.array(cipher_operator.decrypt_list(hess_vectors))
        delta_s = self.this_w_tilde - self.last_w_tilde
        # Store the new curvature pair (s_t, y_t) and refresh the Hessian approximation.
        self.opt_v = self._update_memory_vars(hess_vectors, self.opt_v)
        self.opt_s = self._update_memory_vars(delta_s.unboxed, self.opt_s)
        self._compute_hess_matrix()

    def _update_memory_vars(self, new_vars, memory_vars):
        # Keep at most memory_M recent column vectors, evicting the oldest first.
        if isinstance(new_vars, list):
            new_vars = np.array(new_vars)
        if memory_vars is None:
            memory_vars = [new_vars.reshape(-1, 1)]
        elif len(memory_vars) < self.memory_M:
            memory_vars.append(new_vars.reshape(-1, 1))
        else:
            memory_vars.pop(0)
            memory_vars.append(new_vars.reshape(-1, 1))
        return memory_vars
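    # Illustrative behaviour (not in the original source), with memory_M = 2:
    #   buf = self._update_memory_vars(v1, None)   # [v1]
    #   buf = self._update_memory_vars(v2, buf)    # [v1, v2]
    #   buf = self._update_memory_vars(v3, buf)    # [v2, v3] -- v1 evicted
    # Each v_i is stored reshaped as a column vector of shape (n, 1).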

    def _compute_hess_matrix(self):
        # Initial scaling H_0 = (y_t^T s_t / y_t^T y_t) * I, followed by the
        # BFGS recursion over all stored curvature pairs:
        #   H <- (I - rho * s y^T) H (I - rho * y s^T) + rho * s s^T,  rho = 1 / (y^T s)
        rho = sum(self.opt_v[-1] * self.opt_s[-1]) / sum(self.opt_v[-1] * self.opt_v[-1])
        n = self.opt_s[0].shape[0]
        Hess = rho * np.identity(n)
        for y, s in zip(self.opt_v, self.opt_s):
            rho = 1.0 / (y.T.dot(s))
            Hess = (np.identity(n) - rho * s.dot(y.T)).dot(Hess).dot(
                np.identity(n) - rho * y.dot(s.T)) + rho * s.dot(s.T)
        self.opt_Hess = Hess
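    # Sanity check of the recursion (illustrative, not part of the original source):
    #   s = np.array([[1.0], [0.5]]); y = np.array([[0.9], [0.6]])
    #   rho = 1.0 / float(y.T.dot(s))
    #   I_n = np.identity(2)
    #   H = (I_n - rho * s.dot(y.T)).dot(I_n).dot(I_n - rho * y.dot(s.T)) + rho * s.dot(s.T)
    # H is symmetric and satisfies the secant condition H.dot(y) == s.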


def sqn_factory(role, sqn_param):
    if role == consts.GUEST:
        return HeteroStochasticQuansiNewtonGuest(sqn_param)

    if role == consts.HOST:
        return HeteroStochasticQuansiNewtonHost(sqn_param)

    return HeteroStochasticQuansiNewtonArbiter(sqn_param)
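

# Usage sketch (assumes a running FATE job context; the helper names below are
# illustrative, not prescriptive):
#   sqn_param = StochasticQuasiNewtonParam()
#   sqn_computer = sqn_factory(consts.GUEST, sqn_param)
#   sqn_computer.set_total_batch_nums(total_batch_nums)
#   sqn_computer.register_gradient_computer(plain_gradient_computer)
#   sqn_computer.register_transfer_variable(transfer_variable)
#   # then, once per mini-batch:
#   grad = sqn_computer.compute_gradient_procedure(
#       data_instances, cipher, model_weights, optimizer, n_iter, batch_index)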