homo_binning_base.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import functools
  17. import numpy as np
  18. from federatedml.feature.binning.base_binning import BaseBinning
  19. from federatedml.framework import weights
  20. from fate_arch.session import computing_session as session
  21. from federatedml.param.feature_binning_param import HomoFeatureBinningParam
  22. from federatedml.statistic.data_statistics import MultivariateStatisticalSummary
  23. from federatedml.transfer_variable.transfer_class.homo_binning_transfer_variable import HomoBinningTransferVariable
  24. from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
  25. from federatedml.util import consts
  26. class SplitPointNode(object):
  27. def __init__(self, value, min_value, max_value, aim_rank=None, allow_error_rank=0, last_rank=-1):
  28. self.value = value
  29. self.min_value = min_value
  30. self.max_value = max_value
  31. self.aim_rank = aim_rank
  32. self.allow_error_rank = allow_error_rank
  33. self.last_rank = last_rank
  34. self.fixed = False
  35. def set_aim_rank(self, rank):
  36. self.aim_rank = rank
  37. def create_right_new(self):
  38. value = (self.value + self.max_value) / 2
  39. if np.fabs(value - self.value) <= consts.FLOAT_ZERO * 0.9:
  40. self.value += consts.FLOAT_ZERO * 0.9
  41. self.fixed = True
  42. return self
  43. min_value = self.value
  44. return SplitPointNode(value, min_value, self.max_value, self.aim_rank, self.allow_error_rank)
  45. def create_left_new(self):
  46. value = (self.value + self.min_value) / 2
  47. if np.fabs(value - self.value) <= consts.FLOAT_ZERO * 0.9:
  48. self.value += consts.FLOAT_ZERO * 0.9
  49. self.fixed = True
  50. return self
  51. max_value = self.value
  52. return SplitPointNode(value, self.min_value, max_value, self.aim_rank, self.allow_error_rank)
  53. class RankArray(object):
  54. def __init__(self, rank_array, error_rank, last_rank_array=None):
  55. self.rank_array = rank_array
  56. self.last_rank_array = last_rank_array
  57. self.error_rank = error_rank
  58. self.all_fix = False
  59. self.fixed_array = np.zeros(len(self.rank_array), dtype=bool)
  60. self._compare()
  61. def _compare(self):
  62. if self.last_rank_array is None:
  63. return
  64. else:
  65. self.fixed_array = abs(self.rank_array - self.last_rank_array) < self.error_rank
  66. assert isinstance(self.fixed_array, np.ndarray)
  67. if (self.fixed_array).all():
  68. self.all_fix = True
  69. def __iadd__(self, other: 'RankArray'):
  70. for idx, is_fixed in enumerate(self.fixed_array):
  71. if not is_fixed:
  72. self.rank_array[idx] += other.rank_array[idx]
  73. self._compare()
  74. return self
  75. def __add__(self, other: 'RankArray'):
  76. res_array = []
  77. for idx, is_fixed in enumerate(self.fixed_array):
  78. if not is_fixed:
  79. res_array.append(self.rank_array[idx] + other.rank_array[idx])
  80. else:
  81. res_array.append(self.rank_array[idx])
  82. return RankArray(np.array(res_array), self.error_rank, self.last_rank_array)
  83. class Server(BaseBinning):
  84. def __init__(self, params=None, abnormal_list=None):
  85. super().__init__(params, abnormal_list)
  86. self.aggregator: SecureAggregatorServer = None
  87. self.transfer_variable = HomoBinningTransferVariable()
  88. self.suffix = None
  89. def set_suffix(self, suffix):
  90. self.suffix = suffix
  91. def set_transfer_variable(self, variable):
  92. self.transfer_variable = variable
  93. def set_aggregator(self, aggregator):
  94. self.aggregator = aggregator
  95. def get_total_count(self):
  96. # total_count = self.aggregator.sum_model(suffix=(self.suffix, 'total_count'))
  97. # self.aggregator.send_aggregated_model(total_count, suffix=(self.suffix, 'total_count'))
  98. total_count = self.aggregator.aggregate_model(suffix=(self.suffix, 'total_count'))
  99. self.aggregator.broadcast_model(total_count, suffix=(self.suffix, 'total_count'))
  100. return total_count
  101. def get_missing_count(self):
  102. # missing_count = self.aggregator.sum_model(suffix=(self.suffix, 'missing_count'))
  103. # self.aggregator.send_aggregated_model(missing_count, suffix=(self.suffix, 'missing_count'))
  104. missing_count = self.aggregator.aggregate_model(suffix=(self.suffix, 'missing_count'))
  105. self.aggregator.broadcast_model(missing_count, suffix=(self.suffix, 'missing_count'))
  106. return missing_count
  107. def get_min_max(self):
  108. local_values = self.transfer_variable.local_static_values.get(suffix=(self.suffix, "min-max"))
  109. max_array, min_array = [], []
  110. for local_max, local_min in local_values:
  111. max_array.append(local_max)
  112. min_array.append(local_min)
  113. max_values = np.max(max_array, axis=0)
  114. min_values = np.min(min_array, axis=0)
  115. self.transfer_variable.global_static_values.remote((max_values, min_values),
  116. suffix=(self.suffix, "min-max"))
  117. return min_values, max_values
  118. def query_values(self):
  119. # rank_weight = self.aggregator.aggregate_tables(suffix=(self.suffix, 'rank'))
  120. # self.aggregator.send_aggregated_tables(rank_weight, suffix=(self.suffix, 'rank'))
  121. rank_weight = self.aggregator.aggregate_model(suffix=(self.suffix, 'rank'))
  122. self.aggregator.broadcast_model(rank_weight, suffix=(self.suffix, 'rank'))
  123. class Client(BaseBinning):
  124. def __init__(self, params: HomoFeatureBinningParam = None, abnormal_list=None):
  125. super().__init__(params, abnormal_list)
  126. self.aggregator: SecureAggregatorClient = None
  127. self.transfer_variable = HomoBinningTransferVariable()
  128. self.max_values, self.min_values = None, None
  129. self.suffix = None
  130. self.total_count = 0
  131. def set_suffix(self, suffix):
  132. self.suffix = suffix
  133. def set_transfer_variable(self, variable):
  134. self.transfer_variable = variable
  135. def set_aggregator(self, aggregator):
  136. self.aggregator = aggregator
  137. def get_total_count(self, data_inst):
  138. count = data_inst.count()
  139. count_weight = weights.NumericWeights(count)
  140. self.aggregator.send_model(count_weight, suffix=(self.suffix, 'total_count'))
  141. total_count = self.aggregator.get_aggregated_model(suffix=(self.suffix, 'total_count')).unboxed
  142. return total_count
  143. def get_missing_count(self, summary_table):
  144. missing_table = summary_table.mapValues(lambda x: x.missing_count)
  145. missing_value_counts = dict(missing_table.collect())
  146. missing_weight = weights.DictWeights(missing_value_counts)
  147. self.aggregator.send_model(missing_weight, suffix=(self.suffix, 'missing_count'))
  148. missing_counts = self.aggregator.get_aggregated_model(suffix=(self.suffix, 'missing_count')).unboxed
  149. return missing_counts
  150. def get_min_max(self, data_inst):
  151. """
  152. Get max and min value of each selected columns
  153. Returns:
  154. max_values, min_values: dict
  155. eg. {"x1": 10, "x2": 3, ... }
  156. """
  157. if self.max_values and self.min_values:
  158. return self.max_values, self.min_values
  159. statistic_obj = MultivariateStatisticalSummary(data_inst,
  160. cols_index=self.bin_inner_param.bin_indexes,
  161. abnormal_list=self.abnormal_list,
  162. error=self.params.error)
  163. max_values = statistic_obj.get_max()
  164. min_values = statistic_obj.get_min()
  165. max_list = [max_values[x] for x in self.bin_inner_param.bin_names]
  166. min_list = [min_values[x] for x in self.bin_inner_param.bin_names]
  167. local_min_max_values = (max_list, min_list)
  168. self.transfer_variable.local_static_values.remote(local_min_max_values,
  169. suffix=(self.suffix, "min-max"))
  170. self.max_values, self.min_values = self.transfer_variable.global_static_values.get(
  171. idx=0, suffix=(self.suffix, "min-max"))
  172. return self.max_values, self.min_values
  173. def init_query_points(self, partitions, split_num, error_rank=1, need_first=True):
  174. query_points = []
  175. for idx, col_name in enumerate(self.bin_inner_param.bin_names):
  176. max_value = self.max_values[idx]
  177. min_value = self.min_values[idx]
  178. sps = np.linspace(min_value, max_value, split_num)
  179. if not need_first:
  180. sps = sps[1:]
  181. split_point_array = [SplitPointNode(sps[i], min_value, max_value, allow_error_rank=error_rank)
  182. for i in range(len(sps))]
  183. query_points.append((col_name, split_point_array))
  184. query_points_table = session.parallelize(query_points,
  185. include_key=True,
  186. partition=partitions)
  187. return query_points_table
  188. def query_values(self, summary_table, query_points):
  189. local_ranks = summary_table.join(query_points, self._query_table)
  190. self.aggregator.send_model(local_ranks, suffix=(self.suffix, 'rank'))
  191. global_rank = self.aggregator.get_aggregated_model(suffix=(self.suffix, 'rank'))
  192. global_rank = global_rank.mapValues(lambda x: np.array(x, dtype=int))
  193. return global_rank
  194. @staticmethod
  195. def _query_table(summary, query_points):
  196. queries = [x.value for x in query_points]
  197. original_idx = np.argsort(np.argsort(queries))
  198. queries = np.sort(queries)
  199. ranks = summary.query_value_list(queries)
  200. ranks = np.array(ranks)[original_idx]
  201. return np.array(ranks, dtype=int)