# recursive_query_binning.py
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import numpy as np
  16. from federatedml.feature.binning.quantile_tool import QuantileBinningTool
  17. from federatedml.feature.homo_feature_binning import homo_binning_base
  18. from federatedml.param.feature_binning_param import HomoFeatureBinningParam
  19. from federatedml.util import consts
  20. from federatedml.util import LOGGER
  21. from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
  22. import copy
  23. import operator
  24. import functools
class Server(homo_binning_base.Server):
    """Arbiter-side coordinator for recursive-query homo feature binning.

    Drives the secure-aggregation rounds that let clients binary-search
    their candidate split points against global ranks, stopping when the
    guest reports convergence or ``params.max_iter`` rounds have run.
    """

    def __init__(self, params: HomoFeatureBinningParam, abnormal_list=None):
        super().__init__(params, abnormal_list)

    def fit_split_points(self, data=None):
        # Lazily create the server-side secure aggregator; the suffix must
        # match the one used by Client so the transfer pairs up.
        if self.aggregator is None:
            self.aggregator = SecureAggregatorServer(
                secure_aggregate=True, communicate_match_suffix='recursive_query_binning')
        self.get_total_count()
        self.get_min_max()
        self.get_missing_count()
        # Initial query round uses suffix -1, mirroring the clients' setup round.
        self.set_suffix(-1)
        self.query_values()
        n_iter = 0
        while n_iter < self.params.max_iter:
            self.set_suffix(n_iter)
            # Guest computes convergence and remotes it; take the first copy.
            is_converge = self.transfer_variable.is_converge.get(suffix=self.suffix)[0]
            if is_converge:
                break
            self.query_values()
            n_iter += 1
  45. class Client(homo_binning_base.Client):
  46. def __init__(self, role, params: HomoFeatureBinningParam = None,
  47. abnormal_list=None, allow_duplicate=False):
  48. super().__init__(params, abnormal_list)
  49. self.allow_duplicate = allow_duplicate
  50. self.global_ranks = {}
  51. self.total_count = 0
  52. self.missing_counts = 0
  53. self.error = params.error
  54. self.error_rank = None
  55. self.role = role
  56. def fit_split_points(self, data_instances):
  57. if self.aggregator is None:
  58. self.aggregator = SecureAggregatorClient(
  59. secure_aggregate=True,
  60. aggregate_type='sum',
  61. communicate_match_suffix='recursive_query_binning')
  62. if self.bin_inner_param is None:
  63. self._setup_bin_inner_param(data_instances, self.params)
  64. self.total_count = self.get_total_count(data_instances)
  65. self.error_rank = np.ceil(self.error * self.total_count)
  66. LOGGER.debug(f"abnormal_list: {self.abnormal_list}")
  67. quantile_tool = QuantileBinningTool(param_obj=self.params,
  68. abnormal_list=self.abnormal_list,
  69. allow_duplicate=self.allow_duplicate)
  70. quantile_tool.set_bin_inner_param(self.bin_inner_param)
  71. summary_table = quantile_tool.fit_summary(data_instances)
  72. self.get_min_max(data_instances)
  73. self.missing_counts = self.get_missing_count(summary_table)
  74. split_points_table = self._recursive_querying(summary_table)
  75. split_points = dict(split_points_table.collect())
  76. for col_name, sps in split_points.items():
  77. sp = [x.value for x in sps]
  78. if not self.allow_duplicate:
  79. sp = sorted(set(sp))
  80. res = [sp[0] if np.fabs(sp[0]) > consts.FLOAT_ZERO else 0.0]
  81. last = sp[0]
  82. for v in sp[1:]:
  83. if np.fabs(v) < consts.FLOAT_ZERO:
  84. v = 0.0
  85. if np.abs(v - last) > consts.FLOAT_ZERO:
  86. res.append(v)
  87. last = v
  88. sp = np.array(res)
  89. self.bin_results.put_col_split_points(col_name, sp)
  90. return self.bin_results.all_split_points
  91. @staticmethod
  92. def _set_aim_rank(feature_name, split_point_array, missing_dict, total_counts, split_num):
  93. total_count = total_counts - missing_dict[feature_name]
  94. aim_ranks = [np.floor(x * total_count)
  95. for x in np.linspace(0, 1, split_num)]
  96. aim_ranks = aim_ranks[1:]
  97. for idx, sp in enumerate(split_point_array):
  98. sp.set_aim_rank(aim_ranks[idx])
  99. return feature_name, split_point_array
  100. def _recursive_querying(self, summary_table):
  101. self.set_suffix(-1)
  102. query_points_table = self.init_query_points(summary_table.partitions,
  103. split_num=self.params.bin_num + 1,
  104. error_rank=self.error_rank,
  105. need_first=False)
  106. f = functools.partial(self._set_aim_rank,
  107. missing_dict=self.missing_counts,
  108. total_counts=self.total_count,
  109. split_num=self.params.bin_num + 1)
  110. query_points_table = query_points_table.map(f)
  111. global_ranks = self.query_values(summary_table, query_points_table)
  112. n_iter = 0
  113. while n_iter < self.params.max_iter:
  114. self.set_suffix(n_iter)
  115. query_points_table = query_points_table.join(global_ranks, self.renew_query_points_table)
  116. is_converge = self.check_converge(query_points_table)
  117. if self.role == consts.GUEST:
  118. self.transfer_variable.is_converge.remote(is_converge, suffix=self.suffix)
  119. LOGGER.debug(f"n_iter: {n_iter}, converged: {is_converge}")
  120. if is_converge:
  121. break
  122. global_ranks = self.query_values(summary_table, query_points_table)
  123. n_iter += 1
  124. return query_points_table
  125. @staticmethod
  126. def renew_query_points_table(query_points, ranks):
  127. assert len(query_points) == len(ranks)
  128. new_array = []
  129. for idx, node in enumerate(query_points):
  130. rank = ranks[idx]
  131. if node.fixed:
  132. new_node = copy.deepcopy(node)
  133. elif rank - node.aim_rank > node.allow_error_rank:
  134. new_node = node.create_left_new()
  135. elif node.aim_rank - rank > node.allow_error_rank:
  136. new_node = node.create_right_new()
  137. else:
  138. new_node = copy.deepcopy(node)
  139. new_node.fixed = True
  140. new_node.last_rank = rank
  141. new_array.append(new_node)
  142. return new_array
  143. @staticmethod
  144. def check_converge(query_table):
  145. def is_all_fixed(node_array):
  146. fix_array = [n.fixed for n in node_array]
  147. return functools.reduce(operator.and_, fix_array)
  148. fix_table = query_table.mapValues(is_all_fixed)
  149. return fix_table.reduce(operator.and_)