#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from federatedml.framework.hetero.sync import loss_sync
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.util import LOGGER
from federatedml.util.fate_operator import reduce_add, vec_dot


class Guest(hetero_linear_model_gradient.Guest, loss_sync.Guest):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.host_forward,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.guest_gradient,
                                     transfer_variables.guest_optim_gradient)
        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_half_d(self, data_instances, w, cipher, batch_index, current_suffix):
        if self.use_sample_weight:
            self.half_d = data_instances.mapValues(
                lambda v: (vec_dot(v.features, w.coef_) + w.intercept_ - v.label) * v.weight)
        else:
            self.half_d = data_instances.mapValues(
                lambda v: vec_dot(v.features, w.coef_) + w.intercept_ - v.label)
        return self.half_d
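    # Note: half_d holds the guest-side residual (wx_g + b - y), optionally scaled by the
    # per-sample weight. It is the guest's contribution to the fore gradient; the host-side
    # term wx_h is added on the encrypted side elsewhere in the protocol.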

    def compute_and_aggregate_forwards(self, data_instances, half_g, encrypted_half_g, batch_index,
                                       current_suffix, offset=None):
        """
        gradient = (1/N) * sum(wx - y) * x
        Define (wx_g - y) as guest_forward and wx_h as host_forward.
        """
        self.host_forwards = self.get_host_forward(suffix=current_suffix)
        return self.host_forwards
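    # Only the hosts' (encrypted) forwards are fetched here; half_g and encrypted_half_g are
    # accepted for interface compatibility but are not used in this implementation. Combining
    # host_forwards with half_d into the full residual (wx_g + wx_h - y) is handled outside
    # this method (see hetero_linear_model_gradient).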

    def compute_loss(self, data_instances, n_iter_, batch_index, loss_norm=None):
        """
        Compute the hetero-LinR loss:
            loss = (1/2N) * \\sum(wx - y)^2, where y is the label, w the model weights and x the features.
        Note: (wx - y)^2 = (wx_h)^2 + (wx_g - y)^2 + 2 * wx_h * (wx_g - y)
        """
        current_suffix = (n_iter_, batch_index)
        n = data_instances.count()
        loss_list = []
        host_wx_squares = self.get_host_loss_intermediate(current_suffix)
        if loss_norm is not None:
            host_loss_regular = self.get_host_loss_regular(suffix=current_suffix)
        else:
            host_loss_regular = []

        if len(self.host_forwards) > 1:
            LOGGER.info("More than one host exists, loss is not available")
        else:
            host_forward = self.host_forwards[0]
            host_wx_square = host_wx_squares[0]
            wxy_square = self.half_d.mapValues(lambda x: np.square(x)).reduce(reduce_add)
            loss_gh = self.half_d.join(host_forward, lambda g, h: g * h).reduce(reduce_add)
            loss = (wxy_square + host_wx_square + 2 * loss_gh) / (2 * n)
            if loss_norm is not None:
                loss = loss + loss_norm + host_loss_regular[0]
            loss_list.append(loss)
        # LOGGER.debug("In compute_loss, loss list are: {}".format(loss_list))
        self.sync_loss_info(loss_list, suffix=current_suffix)
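    # Worked decomposition (why only [wx_h^2] and [wx_h] are needed from the host):
    #   (wx - y)^2 = (wx_g + wx_h - y)^2
    #              = (wx_g - y)^2 + (wx_h)^2 + 2 * wx_h * (wx_g - y)
    # The guest evaluates (wx_g - y)^2 locally (wxy_square above), receives the host's
    # encrypted sum of (wx_h)^2 (host_wx_square), and forms the cross term from half_d and
    # the host's encrypted forward, so the label y never leaves the guest.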

    def compute_forward_hess(self, data_instances, delta_s, host_forwards):
        """
        To compute the Hessian-vector product, y and s are needed:
            g = (1/N) * ∑(wx - y) * x
            y = ∇²F(w_t) s_t = g' * s = (1/N) * ∑(x * s) * x
        Define forward_hess = (1/N) * ∑(x * s)
        """
        forwards = data_instances.mapValues(
            lambda v: (vec_dot(v.features, delta_s.coef_) + delta_s.intercept_))
        for host_forward in host_forwards:
            forwards = forwards.join(host_forward, lambda g, h: g + h)
        if self.use_sample_weight:
            forwards = forwards.join(data_instances, lambda h, d: h * d.weight)

        hess_vector = self.compute_gradient(data_instances,
                                            forwards,
                                            delta_s.fit_intercept)
        return forwards, np.array(hess_vector)
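    # forward_hess is the per-sample directional predictor x·s (summed over guest and hosts,
    # and optionally sample-weighted); feeding it back through compute_gradient yields
    # (1/N) * ∑(x·s) * x, i.e. the Hessian-vector product of the squared loss with s.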


class Host(hetero_linear_model_gradient.Host, loss_sync.Host):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.host_forward,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.host_gradient,
                                     transfer_variables.host_optim_gradient)
        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_forwards(self, data_instances, model_weights):
        wx = data_instances.mapValues(
            lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
        return wx

    def compute_half_g(self, data_instances, w, cipher, batch_index):
        half_g = data_instances.mapValues(
            lambda v: vec_dot(v.features, w.coef_) + w.intercept_)
        encrypt_half_g = cipher[batch_index].encrypt(half_g)
        return half_g, encrypt_half_g
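    # The host keeps half_g (its plaintext wx_h) locally and ships only the encrypted copy;
    # cipher[batch_index] is the per-batch encrypter supplied by the caller (typically an
    # additively homomorphic scheme, so the guest can aggregate forwards without decrypting).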

    def compute_loss(self, model_weights, optimizer, n_iter_, batch_index, cipher_operator):
        """
        Compute the hetero-LinR loss:
            loss = (1/2N) * \\sum(wx - y)^2, where y is the label, w the model weights and x the features.
        Note: (wx - y)^2 = (wx_h)^2 + (wx_g - y)^2 + 2 * wx_h * (wx_g - y)
        """
        current_suffix = (n_iter_, batch_index)
        self_wx_square = self.forwards.mapValues(lambda x: np.square(x)).reduce(reduce_add)
        en_wx_square = cipher_operator.encrypt(self_wx_square)
        self.remote_loss_intermediate(en_wx_square, suffix=current_suffix)

        loss_regular = optimizer.loss_norm(model_weights)
        if loss_regular is not None:
            en_loss_regular = cipher_operator.encrypt(loss_regular)
            self.remote_loss_regular(en_loss_regular, suffix=current_suffix)
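    # The host contributes only the encrypted sum of (wx_h)^2 and, when regularization is
    # enabled, its encrypted regularization term; it never sees the labels or the assembled loss.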


class Arbiter(hetero_linear_model_gradient.Arbiter, loss_sync.Arbiter):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.guest_gradient,
                                     transfer_variables.host_gradient,
                                     transfer_variables.guest_optim_gradient,
                                     transfer_variables.host_optim_gradient)
        self._register_loss_sync(transfer_variables.loss)

    def compute_loss(self, cipher, n_iter_, batch_index):
        """
        Decrypt the loss received from the guest.
        """
        current_suffix = (n_iter_, batch_index)
        loss_list = self.sync_loss_info(suffix=current_suffix)
        de_loss_list = cipher.decrypt_list(loss_list)
        return de_loss_list
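

# Minimal plain-numpy sanity check of the loss decomposition used above. This is an
# illustrative sketch only: it does not touch the federation or encryption machinery,
# and every name below is local to the demo (not part of the FATE API).
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    n = 8
    wx_g = rng.rand(n)   # guest-side linear predictor (wx_g + b)
    wx_h = rng.rand(n)   # host-side linear predictor (wx_h)
    y = rng.rand(n)      # labels, held by the guest only

    direct = np.sum((wx_g + wx_h - y) ** 2) / (2 * n)
    decomposed = (np.sum((wx_g - y) ** 2)
                  + np.sum(wx_h ** 2)
                  + 2 * np.sum(wx_h * (wx_g - y))) / (2 * n)
    assert np.isclose(direct, decomposed)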