hetero_lr_gradient_and_loss.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from federatedml.framework.hetero.sync import loss_sync
from federatedml.optim.gradient import hetero_linear_model_gradient
from federatedml.util import LOGGER
from federatedml.util.fate_operator import reduce_add, vec_dot


class Guest(hetero_linear_model_gradient.Guest, loss_sync.Guest):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.host_forward_dict,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.guest_gradient,
                                     transfer_variables.guest_optim_gradient)
        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)
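
    # half_d is the guest's share of the approximate fore_gradient,
    # 0.25 * (w_g . x + b) - 0.5 * y, following the Taylor expansion described in
    # compute_and_aggregate_forwards; with sample weights it is scaled per instance.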

    def compute_half_d(self, data_instances, w, cipher, batch_index, current_suffix):
        if self.use_sample_weight:
            self.half_d = data_instances.mapValues(
                lambda v: 0.25 * (vec_dot(v.features, w.coef_) + w.intercept_) * v.weight - 0.5 * v.label * v.weight)
        else:
            self.half_d = data_instances.mapValues(
                lambda v: 0.25 * (vec_dot(v.features, w.coef_) + w.intercept_) - 0.5 * v.label)
        # encrypted_half_d = cipher[batch_index].encrypt(self.half_d)
        # self.fore_gradient_transfer.remote(encrypted_half_d, suffix=current_suffix)
        return self.half_d
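
    # Here the guest only retrieves the hosts' encrypted forwards; the commented-out
    # lines below are the earlier in-place fore_gradient aggregation, apparently kept
    # for reference.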

    def compute_and_aggregate_forwards(self, data_instances, half_g, encrypted_half_g, batch_index,
                                       current_suffix, offset=None):
        """
        gradient = (1/N) * ∑ (1/2 * ywx - 1) * 1/2 * yx = (1/N) * ∑ (0.25 * wx - 0.5 * y) * x, where y = 1 or -1
        Define wx as guest_forward or host_forward
        Define (0.25 * wx - 0.5 * y) as fore_gradient
        """
        self.host_forwards = self.get_host_forward(suffix=current_suffix)
        # fore_gradient = half_g
        # for host_forward in self.host_forwards:
        #     fore_gradient = fore_gradient.join(host_forward, lambda g, h: g + h)
        # fore_gradient = self.aggregated_forwards.join(data_instances, lambda wx, d: 0.25 * wx - 0.5 * d.label)
        return self.host_forwards

    def compute_loss(self, data_instances, w, n_iter_, batch_index, loss_norm=None, batch_masked=False):
        """
        Compute the hetero-lr loss:
            loss = (1/N) * ∑ (log2 - 1/2 * ywx + 1/8 * (wx)^2), where y is the label, w the model weights and x the features,
            and (wx)^2 = (Wg*Xg + Wh*Xh)^2 = (Wg*Xg)^2 + (Wh*Xh)^2 + 2 * Wg*Xg * Wh*Xh.
        Then loss = log2 - (1/N) * 0.5 * ∑ywx + (1/N) * 0.125 * [∑(Wg*Xg)^2 + ∑(Wh*Xh)^2 + 2 * ∑(Wg*Xg * Wh*Xh)],
        where Wh*Xh is a table obtained from the host and ∑(Wh*Xh)^2 is a scalar sum obtained from the host.
        """
        current_suffix = (n_iter_, batch_index)
        n = data_instances.count()

        # host_wx_y = self.host_forwards[0].join(data_instances, lambda x, y: (x, y.label))
        host_wx_y = data_instances.join(self.host_forwards[0], lambda y, x: (x, y.label))
        self_wx_y = self.half_d.join(data_instances, lambda x, y: (x, y.label))

        def _sum_ywx(wx_y):
            sum1, sum2 = 0, 0
            for _, (x, y) in wx_y:
                if y == 1:
                    sum1 += x
                else:
                    sum2 -= x
            return sum1 + sum2

        ywx = host_wx_y.applyPartitions(_sum_ywx).reduce(reduce_add) + \
            self_wx_y.applyPartitions(_sum_ywx).reduce(reduce_add)
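        # The partial sums above are built from quarter-scaled forwards
        # (half_d = 0.25 * (w_g . x + b) - 0.5 * y and host forward = 0.25 * w_h . x),
        # so 4 * ywx + 2 * n recovers the plain ∑ y*w*x (using y^2 = 1).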
        ywx = ywx * 4 + 2 * n
        # quarter_wx = self.host_forwards[0].join(self.half_d, lambda x, y: x + y)
        # ywx = quarter_wx.join(data_instances, lambda wx, d: wx * (4 * d.label) + 2).reduce(reduce_add)

        half_wx = data_instances.mapValues(
            lambda v: vec_dot(v.features, w.coef_) + w.intercept_)
        self_wx_square = half_wx.mapValues(
            lambda v: np.square(v)).reduce(reduce_add)
        # self_wx_square = data_instances.mapValues(
        #     lambda v: np.square(vec_dot(v.features, w.coef_) + w.intercept_)).reduce(reduce_add)
        loss_list = []

        wx_squares = self.get_host_loss_intermediate(suffix=current_suffix)
        if batch_masked:
            wx_squares_sum = []
            for square_table in wx_squares:
                square_sum = data_instances.join(
                    square_table, lambda inst, enc_h_squares: enc_h_squares).reduce(lambda x, y: x + y)
                wx_squares_sum.append(square_sum)
            wx_squares = wx_squares_sum

        if loss_norm is not None:
            host_loss_regular = self.get_host_loss_regular(suffix=current_suffix)
        else:
            host_loss_regular = []

        # for host_idx, host_forward in enumerate(self.host_forwards):
        if len(self.host_forwards) > 1:
            LOGGER.info("More than one host exists, loss is not available")
        else:
            host_forward = self.host_forwards[0]
            wx_square = wx_squares[0]
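            # host_forward holds 0.25 * wx_h, so 8 * wxg_wxh below equals the
            # cross term 2 * ∑(Wg*Xg * Wh*Xh) from the docstring expansion.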
            wxg_wxh = half_wx.join(host_forward, lambda wxg, wxh: wxg * wxh).reduce(reduce_add)

            loss = np.log(2) - 0.5 * (1 / n) * ywx + 0.125 * (1 / n) * \
                (self_wx_square + wx_square + 8 * wxg_wxh)
            if loss_norm is not None:
                loss += loss_norm
                loss += host_loss_regular[0]
            loss_list.append(loss)

        LOGGER.debug("In compute_loss, loss list is: {}".format(loss_list))
        self.sync_loss_info(loss_list, suffix=current_suffix)

    def compute_forward_hess(self, data_instances, delta_s, host_forwards):
        """
        To compute the Hessian matrix, y and s are needed.
        g = (1/N) * ∑ (0.25 * wx - 0.5 * y) * x
        y = ∇²F(w_t)·s_t = g'·s = (1/N) * ∑ (0.25 * x·s) * x
        Define forward_hess = (1/N) * ∑ (0.25 * x·s)
        """
        forwards = data_instances.mapValues(
            lambda v: (vec_dot(v.features, delta_s.coef_) + delta_s.intercept_) * 0.25)
        for host_forward in host_forwards:
            forwards = forwards.join(host_forward, lambda g, h: g + (h * 0.25))
        if self.use_sample_weight:
            forwards = forwards.join(data_instances, lambda h, d: h * d.weight)
        # forward_hess = forwards.mapValues(lambda x: 0.25 * x / sample_size)
        hess_vector = self.compute_gradient(data_instances,
                                            forwards,
                                            delta_s.fit_intercept)
        return forwards, np.array(hess_vector)


class Host(hetero_linear_model_gradient.Host, loss_sync.Host):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.host_forward_dict,
                                     transfer_variables.fore_gradient,
                                     transfer_variables.host_gradient,
                                     transfer_variables.host_optim_gradient)
        self._register_loss_sync(transfer_variables.host_loss_regular,
                                 transfer_variables.loss,
                                 transfer_variables.loss_intermediate)

    def compute_forwards(self, data_instances, model_weights):
        """
        forwards = 1/4 * wx
        """
        # wx = data_instances.mapValues(lambda v: vec_dot(v.features, model_weights.coef_) + model_weights.intercept_)
        self.forwards = data_instances.mapValues(lambda v: 0.25 * vec_dot(v.features, model_weights.coef_))
        return self.forwards
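
    # The host's half gradient share is 0.25 * w_h . x plus the host intercept term;
    # it is encrypted with the batch cipher before being sent to the guest.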

    def compute_half_g(self, data_instances, w, cipher, batch_index):
        half_g = data_instances.mapValues(
            lambda v: vec_dot(v.features, w.coef_) * 0.25 + w.intercept_)
        encrypt_half_g = cipher[batch_index].encrypt(half_g)
        return half_g, encrypt_half_g

    def compute_loss(self, lr_weights, optimizer, n_iter_, batch_index, cipher_operator, batch_masked=False):
        """
        Compute the hetero-lr loss:
            loss = (1/N) * ∑ (log2 - 1/2 * ywx + 1/8 * (wx)^2), where y is the label, w the model weights and x the features,
            and (wx)^2 = (Wg*Xg + Wh*Xh)^2 = (Wg*Xg)^2 + (Wh*Xh)^2 + 2 * Wg*Xg * Wh*Xh.
        Then loss = log2 - (1/N) * 0.5 * ∑ywx + (1/N) * 0.125 * [∑(Wg*Xg)^2 + ∑(Wh*Xh)^2 + 2 * ∑(Wg*Xg * Wh*Xh)],
        where Wh*Xh is a table obtained from the host and ∑(Wh*Xh)^2 is a scalar sum obtained from the host.
        """
        current_suffix = (n_iter_, batch_index)
        # self_wx_square = self.forwards.mapValues(lambda x: np.square(4 * x)).reduce(reduce_add)
        self_wx_square = self.forwards.mapValues(lambda x: np.square(4 * x))
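        # forwards hold 0.25 * wx_h, so squaring 4 * x gives (wx_h)^2 per instance.
        # For a masked batch the host cannot sum locally; it sends per-instance encrypted
        # squares and the guest sums them after joining with its own batch instances.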
        if not batch_masked:
            self_wx_square = self_wx_square.reduce(reduce_add)
            en_wx_square = cipher_operator.encrypt(self_wx_square)
        else:
            en_wx_square = self_wx_square.mapValues(lambda x: cipher_operator.encrypt(x))

        self.remote_loss_intermediate(en_wx_square, suffix=current_suffix)

        loss_regular = optimizer.loss_norm(lr_weights)
        if loss_regular is not None:
            en_loss_regular = cipher_operator.encrypt(loss_regular)
            self.remote_loss_regular(en_loss_regular, suffix=current_suffix)


class Arbiter(hetero_linear_model_gradient.Arbiter, loss_sync.Arbiter):

    def register_gradient_procedure(self, transfer_variables):
        self._register_gradient_sync(transfer_variables.guest_gradient,
                                     transfer_variables.host_gradient,
                                     transfer_variables.guest_optim_gradient,
                                     transfer_variables.host_optim_gradient)
        self._register_loss_sync(transfer_variables.loss)
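
    # The arbiter only decrypts the loss values assembled and sent by the guest;
    # when more than one host participates the loss is not computed.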

    def compute_loss(self, cipher, n_iter_, batch_index):
        """
        Compute the hetero-lr loss:
            loss = (1/N) * ∑ (log2 - 1/2 * ywx + 1/8 * (wx)^2), where y is the label, w the model weights and x the features,
            and (wx)^2 = (Wg*Xg + Wh*Xh)^2 = (Wg*Xg)^2 + (Wh*Xh)^2 + 2 * Wg*Xg * Wh*Xh.
        Then loss = log2 - (1/N) * 0.5 * ∑ywx + (1/N) * 0.125 * [∑(Wg*Xg)^2 + ∑(Wh*Xh)^2 + 2 * ∑(Wg*Xg * Wh*Xh)],
        where Wh*Xh is a table obtained from the host and ∑(Wh*Xh)^2 is a scalar sum obtained from the host.
        """
        if self.has_multiple_hosts:
            LOGGER.info("More than one host exists, loss is not available")
            return []

        current_suffix = (n_iter_, batch_index)
        loss_list = self.sync_loss_info(suffix=current_suffix)
        de_loss_list = cipher.decrypt_list(loss_list)
        return de_loss_list
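

if __name__ == "__main__":
    # Illustrative sanity check only, not part of the FATE training flow: verify that
    # the second-order Taylor expansion used in the loss docstrings above,
    #     log(1 + exp(-y*wx)) ≈ log(2) - 0.5*y*wx + 0.125*(wx)^2   (y in {-1, +1}),
    # stays close to the exact logistic loss for small |wx|. All values below are
    # made-up demo numbers, not anything produced by the federated protocol.
    rng = np.random.RandomState(0)
    wx = rng.uniform(-1.0, 1.0, size=1000)
    y = rng.choice([-1.0, 1.0], size=1000)

    exact = np.mean(np.log1p(np.exp(-y * wx)))
    approx = np.mean(np.log(2) - 0.5 * y * wx + 0.125 * np.square(wx))

    print("exact logistic loss :", exact)
    print("Taylor-approx loss  :", approx)
    print("absolute difference :", abs(exact - approx))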