#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import functools
import operator

from federatedml.cipher_compressor.compressor import CipherCompressorHost
from federatedml.feature.hetero_feature_binning.base_feature_binning import BaseFeatureBinning
from federatedml.secureprotol.fate_paillier import PaillierEncryptedNumber
from federatedml.util import LOGGER
from federatedml.util import consts


class HeteroFeatureBinningHost(BaseFeatureBinning):
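    """Host-side feature binning for hetero federated learning.

    The host computes split points on its local features, sums the guest's
    homomorphically encrypted labels within each bin, and sends the anonymized,
    compressed sums to the guest, which alone can decrypt them and evaluate
    the bin statistics.
    """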

    def __init__(self):
        super(HeteroFeatureBinningHost, self).__init__()
        self.compressor = None

    def fit(self, data_instances):
        self._abnormal_detection(data_instances)
        # self._parse_cols(data_instances)
        self._setup_bin_inner_param(data_instances, self.model_param)

        if self.model_param.method == consts.OPTIMAL:
            has_missing_value = self.iv_calculator.check_containing_missing_value(data_instances)
            for idx in self.bin_inner_param.bin_indexes:
                if idx in has_missing_value:
                    raise ValueError("Optimal binning does not support missing values yet.")

        if self.model_param.split_points_by_col_name or self.model_param.split_points_by_index:
            split_points = self._get_manual_split_points(data_instances)
            self.use_manual_split_points = True
            for col_name, sp in split_points.items():
                self.binning_obj.bin_results.put_col_split_points(col_name, sp)
        else:
            # Calculate split points for the data held by this party
            split_points = self.binning_obj.fit_split_points(data_instances)
        return self.stat_and_transform(data_instances, split_points)

    def transform(self, data_instances):
        self._setup_bin_inner_param(data_instances, self.model_param)
        split_points = self.binning_obj.bin_results.all_split_points
        return self.stat_and_transform(data_instances, split_points)

    def stat_and_transform(self, data_instances, split_points):
        """
        Apply the binning method to the local party's data instances as well as
        the other party's, then calculate the metric values for the specified columns.
        """
        if self.model_param.skip_static:
            # if self.transform_type != 'woe':
            data_instances = self.transform_data(data_instances)
            # else:
            #     raise ValueError("Woe transform is not supported in host parties.")
            self.set_schema(data_instances)
            self.data_output = data_instances
            return data_instances
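
        # Federated path: sum the guest's encrypted labels per bin. Skipped when
        # running locally or when the guest reports no labels at transform time.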
        if not self.model_param.local_only:
            has_label = True
            if self._stage == "transform":
                has_label = self.transfer_variable.transform_stage_has_label.get(idx=0)
            if has_label:
                self.compressor = CipherCompressorHost()
                self._sync_init_bucket(data_instances, split_points)
                if self.model_param.method == consts.OPTIMAL and self._stage == "fit":
                    self.optimal_binning_sync()

        # if self.transform_type != 'woe':
        data_instances = self.transform_data(data_instances)
        self.set_schema(data_instances)
        self.data_output = data_instances
        total_summary = self.binning_obj.bin_results.to_json()
        self.set_summary(total_summary)
        return data_instances

    def _sync_init_bucket(self, data_instances, split_points, need_shuffle=False):
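        """Bin the local data, sum the guest's encrypted labels within each bin,
        and send the compressed, anonymized sums (plus binning parameters) to the guest.
        """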
        data_bin_table = self.binning_obj.get_data_bin(data_instances, split_points,
                                                       self.bin_inner_param.bin_cols_map)
        # LOGGER.debug("data_bin_table, count: {}".format(data_bin_table.count()))
        encrypted_label_table = self.transfer_variable.encrypted_label.get(idx=0)
        LOGGER.info("Get encrypted_label_table from guest")

        encrypted_bin_sum = self.__static_encrypted_bin_label(data_bin_table, encrypted_label_table)
        encrypted_bin_sum = self.compressor.compress_dtable(encrypted_bin_sum)

        encode_name_f = functools.partial(self.bin_inner_param.change_to_anonymous,
                                          col_name_anonymous_maps=self.bin_inner_param.col_name_anonymous_maps)
        # encrypted_bin_sum = self.bin_inner_param.encode_col_name_dict(encrypted_bin_sum, self)
        encrypted_bin_sum = encrypted_bin_sum.map(encode_name_f)

        # encrypted_bin_sum = self.cipher_compress(encrypted_bin_sum, data_bin_table.count())
        self.transfer_variable.encrypted_bin_sum.remote(encrypted_bin_sum,
                                                        role=consts.GUEST,
                                                        idx=0)
        send_result = {
            "category_names": self.bin_inner_param.get_anonymous_col_name_list(
                self.bin_inner_param.category_names),
            "bin_method": self.model_param.method,
            "optimal_params": {
                "metric_method": self.model_param.optimal_binning_param.metric_method,
                "bin_num": self.model_param.bin_num,
                "mixture": self.model_param.optimal_binning_param.mixture,
                "max_bin_pct": self.model_param.optimal_binning_param.max_bin_pct,
                "min_bin_pct": self.model_param.optimal_binning_param.min_bin_pct
            }
        }
        self.transfer_variable.optimal_info.remote(send_result,
                                                   role=consts.GUEST,
                                                   idx=0)

    def __static_encrypted_bin_label(self, data_bin_table, encrypted_label):
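        """Compute per-bin sums of the encrypted labels. Adding Paillier
        ciphertexts yields the encryption of the sum, so the host can aggregate
        event counts without ever seeing a label in the clear.
        """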
        # data_bin_with_label = data_bin_table.join(encrypted_label, lambda x, y: (x, y))
        label_counts = encrypted_label.reduce(operator.add)
        sparse_bin_points = self.binning_obj.get_sparse_bin(self.bin_inner_param.bin_indexes,
                                                            self.binning_obj.bin_results.all_split_points,
                                                            self.bin_inner_param.header)
        sparse_bin_points = {self.bin_inner_param.header[k]: v for k, v in sparse_bin_points.items()}

        encrypted_bin_sum = self.iv_calculator.cal_bin_label(
            data_bin_table=data_bin_table,
            sparse_bin_points=sparse_bin_points,
            label_table=encrypted_label,
            label_counts=label_counts
        )
        return encrypted_bin_sum

    @staticmethod
    def convert_compress_format(col_name, encrypted_bin_sum):
        """
        Parameters
        ----------
        col_name : str
            Name of the column the counts belong to.
        encrypted_bin_sum : list
            Per-bin pairs for one column, i.e. one value of a table like:
            {'x1': [[event_count, non_event_count], [event_count, non_event_count], ...],
             'x2': [[event_count, non_event_count], [event_count, non_event_count], ...],
             ...
            }

        Returns
        -------
        tuple
            (col_name, {"event_counts": [...], "non_event_counts": [...]})
        """
        event_counts = [x[0] for x in encrypted_bin_sum]
        non_event_counts = [x[1] for x in encrypted_bin_sum]
        return col_name, {"event_counts": event_counts, "non_event_counts": non_event_counts}
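
    # For example (e0, ne0, ... are illustrative values; in practice they may be
    # PaillierEncryptedNumber instances):
    #   convert_compress_format("x1", [[e0, ne0], [e1, ne1]])
    #   -> ("x1", {"event_counts": [e0, e1], "non_event_counts": [ne0, ne1]})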

    def optimal_binning_sync(self):
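        """Receive the bucket indexes chosen by the guest's optimal binning and
        keep only the corresponding local split points, mapping anonymous column
        names back to their real names.
        """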
        bucket_idx = self.transfer_variable.bucket_idx.get(idx=0)
        # LOGGER.debug("In optimal_binning_sync, received bucket_idx: {}".format(bucket_idx))
        original_split_points = self.binning_obj.bin_results.all_split_points
        for anonymous_col_name, b_idx in bucket_idx.items():
            col_name = self.bin_inner_param.get_col_name_by_anonymous(anonymous_col_name)
            ori_sp_list = original_split_points.get(col_name)
            optimal_result = [ori_sp_list[i] for i in b_idx]
            self.binning_obj.bin_results.put_col_split_points(col_name, optimal_result)
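

# A minimal sketch of the homomorphic property this module relies on, kept as a
# comment so the library file stays side-effect free. It assumes the
# PaillierKeypair API from federatedml.secureprotol.fate_paillier (an assumption;
# only PaillierEncryptedNumber is imported above):
#
#   from federatedml.secureprotol.fate_paillier import PaillierKeypair
#   pub_key, priv_key = PaillierKeypair.generate_keypair(1024)
#   enc_labels = [pub_key.encrypt(y) for y in (1, 0, 1)]   # done by the guest
#   enc_sum = functools.reduce(operator.add, enc_labels)   # done by the host
#   assert priv_key.decrypt(enc_sum) == 2                  # only the guest can decrypt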