hetero_poisson_gradient_and_loss.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from federatedml.framework.hetero.sync import loss_sync
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.util.fate_operator import reduce_add, vec_dot

class Guest(hetero_linear_model_gradient.Guest, loss_sync.Guest):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.host_forward,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.guest_gradient,
                                     transfer_variables.guest_optim_gradient)
        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_gradient_procedure(self, data_instances, cipher, model_weights, optimizer,
                                   n_iter_, batch_index, offset=None):
        current_suffix = (n_iter_, batch_index)

        fore_gradient = self.compute_and_aggregate_forwards(data_instances, model_weights, cipher,
                                                            batch_index, current_suffix, offset)
        self.remote_fore_gradient(fore_gradient, suffix=current_suffix)

        unilateral_gradient = self.compute_gradient(data_instances,
                                                    fore_gradient,
                                                    model_weights.fit_intercept)
        if optimizer is not None:
            unilateral_gradient = optimizer.add_regular_to_grad(unilateral_gradient, model_weights)

        optimized_gradient = self.update_gradient(unilateral_gradient, suffix=current_suffix)
        return optimized_gradient
    def compute_and_aggregate_forwards(self, data_instances, model_weights, cipher,
                                       batch_index, current_suffix, offset=None):
        '''
        Compute gradients:
            gradient = (1/N) * \\sum((exp(wx) - y) * x)

        Define mu = exp(wx), named guest_forward or host_forward,
        and define fore_gradient = mu - y.
        Then gradient = fore_gradient * x.
        (A plaintext sketch of this arithmetic follows the class definition.)
        '''
        if offset is None:
            raise ValueError("Offset should be provided when computing poisson forwards")
        mu = data_instances.join(offset, lambda d, m: np.exp(vec_dot(d.features, model_weights.coef_)
                                                             + model_weights.intercept_ + m))
        self.forwards = mu
        self.host_forwards = self.get_host_forward(suffix=current_suffix)
        self.aggregated_forwards = self.forwards.join(self.host_forwards[0], lambda g, h: g * h)
        fore_gradient = self.aggregated_forwards.join(data_instances, lambda mu, d: mu - d.label)
        return fore_gradient
    def compute_loss(self, data_instances, model_weights, n_iter_, batch_index, offset, loss_norm=None):
        '''
        Compute hetero poisson loss:
            loss = (1/N) * sum(exp(wx_g + wx_h + offset) - y * (wx_g + wx_h + offset))
        (the log(y!) constant term is dropped)

        Parameters
        ----------
        data_instances: Table, input data
        model_weights: model weight object, stores intercept_ and coef_
        n_iter_: int, current iteration number
        batch_index: int, used to obtain the current encrypted_calculator index
        offset: log(exposure)
        loss_norm: penalty term, default None
        '''
        current_suffix = (n_iter_, batch_index)
        n = data_instances.count()
        guest_wx_y = data_instances.join(offset,
                                         lambda v, m: (
                                             vec_dot(v.features, model_weights.coef_) + model_weights.intercept_ + m,
                                             v.label))
        loss_list = []
        host_wxs = self.get_host_loss_intermediate(current_suffix)
        if loss_norm is not None:
            host_loss_regular = self.get_host_loss_regular(suffix=current_suffix)
        else:
            host_loss_regular = []

        if len(self.host_forwards) > 1:
            raise ValueError("More than one host exists. Poisson regression does not support multi-host.")

        host_mu = self.host_forwards[0]
        host_wx = host_wxs[0]
        loss_wx = guest_wx_y.join(host_wx, lambda g, h: g[1] * (g[0] + h)).reduce(reduce_add)
        loss_mu = self.forwards.join(host_mu, lambda g, h: g * h).reduce(reduce_add)
        loss = (loss_mu - loss_wx) / n
        if loss_norm is not None:
            loss = loss + loss_norm + host_loss_regular[0]
        loss_list.append(loss)
        self.sync_loss_info(loss_list, suffix=current_suffix)
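
# ---------------------------------------------------------------------------
# Illustrative plaintext sketch (not part of FATE's API): the single-party
# arithmetic that the Guest/Host classes split across parties and encryption.
# The helper name `_poisson_gradient_plaintext` and its arguments are
# hypothetical and exist only to document the formula in the docstring above.
def _poisson_gradient_plaintext(features, labels, coef, intercept, offset):
    """Plaintext check of: mu = exp(wx + b + offset), fore_gradient = mu - y,
    gradient = (1/N) * sum(fore_gradient * x)."""
    wx = features @ coef + intercept + offset   # linear predictor with log-exposure offset
    mu = np.exp(wx)                             # product of guest_forward and host_forward in the federated case
    fore_gradient = mu - labels                 # the (encrypted) value the guest remotes to the host
    grad_coef = features.T @ fore_gradient / len(labels)
    grad_intercept = fore_gradient.mean()       # bias-term gradient when fit_intercept is used
    return np.append(grad_coef, grad_intercept)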

class Host(hetero_linear_model_gradient.Host, loss_sync.Host):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.host_forward,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.host_gradient,
                                     transfer_variables.host_optim_gradient)
        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_gradient_procedure(self, data_instances, cipher, model_weights,
                                   optimizer,
                                   n_iter_, batch_index):
        """
        Linear model gradient procedure
        Step 1: get host forwards, which differ for different algorithms
        """
        current_suffix = (n_iter_, batch_index)

        self.forwards = self.compute_forwards(data_instances, model_weights)
        encrypted_forward = cipher.distribute_encrypt(self.forwards)
        self.remote_host_forward(encrypted_forward, suffix=current_suffix)

        fore_gradient = self.get_fore_gradient(suffix=current_suffix)

        unilateral_gradient = self.compute_gradient(data_instances,
                                                    fore_gradient,
                                                    model_weights.fit_intercept)
        if optimizer is not None:
            unilateral_gradient = optimizer.add_regular_to_grad(unilateral_gradient, model_weights)

        optimized_gradient = self.update_gradient(unilateral_gradient, suffix=current_suffix)
        return optimized_gradient
    def compute_forwards(self, data_instances, model_weights):
        mu = data_instances.mapValues(
            lambda v: np.exp(vec_dot(v.features, model_weights.coef_) + model_weights.intercept_))
        return mu
    def compute_loss(self, data_instances, model_weights,
                     optimizer, n_iter_, batch_index, cipher):
        '''
        Compute the host contribution to the hetero poisson loss: encrypt wx_h
        and the regularization term (if any) and send them to the guest, which
        assembles the full loss.
        (A plaintext sketch of the assembled loss follows this class.)

        Parameters
        ----------
        data_instances: Table, input data
        model_weights: model weight object, stores intercept_ and coef_
        optimizer: optimizer object
        n_iter_: int, current iteration number
        batch_index: int, current batch index
        cipher: cipher used to encrypt the intermediate loss term and loss_regular
        '''
        current_suffix = (n_iter_, batch_index)
        self_wx = data_instances.mapValues(
            lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
        en_wx = cipher.distribute_encrypt(self_wx)
        self.remote_loss_intermediate(en_wx, suffix=current_suffix)

        loss_regular = optimizer.loss_norm(model_weights)
        if loss_regular is not None:
            en_loss_regular = cipher.encrypt(loss_regular)
            self.remote_loss_regular(en_loss_regular, suffix=current_suffix)
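
# ---------------------------------------------------------------------------
# Illustrative plaintext sketch (not part of FATE's API): the loss that
# Guest.compute_loss assembles from the host's encrypted wx. The helper name
# `_poisson_loss_plaintext` and its arguments are hypothetical.
def _poisson_loss_plaintext(wx_guest, wx_host, labels, offset):
    """Plaintext check of: loss = (1/N) * sum(exp(wx_g + wx_h + offset)
    - y * (wx_g + wx_h + offset)); the log(y!) constant is dropped, as in the
    federated computation."""
    z = wx_guest + wx_host + offset
    return float(np.mean(np.exp(z) - labels * z))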

class Arbiter(hetero_linear_model_gradient.Arbiter, loss_sync.Arbiter):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.guest_gradient,
                                     transfer_variables.host_gradient,
                                     transfer_variables.guest_optim_gradient,
                                     transfer_variables.host_optim_gradient)
        self._register_loss_sync(transfer_variables.loss)

    def compute_loss(self, cipher, n_iter_, batch_index):
        '''
        Decrypt loss from guest
        '''
        current_suffix = (n_iter_, batch_index)
        loss_list = self.sync_loss_info(suffix=current_suffix)
        de_loss_list = cipher.decrypt_list(loss_list)
        return de_loss_list
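
# ---------------------------------------------------------------------------
# Optional self-check on toy data for the two illustrative helpers above.
# They are not used by the federated classes; running this block requires an
# environment where the federatedml imports at the top of this file resolve.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(8, 3))
    y = rng.poisson(lam=2.0, size=8).astype(float)
    coef, intercept, offset = np.zeros(3), 0.0, np.zeros(8)
    print("gradient:", _poisson_gradient_plaintext(x, y, coef, intercept, offset))
    print("loss    :", _poisson_loss_plaintext(x @ coef + intercept, np.zeros(8), y, offset))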