virtual_summary_binning.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
import functools

import numpy as np

from federatedml.feature.binning.quantile_tool import QuantileBinningTool
from federatedml.feature.homo_feature_binning import homo_binning_base
from federatedml.param.feature_binning_param import HomoFeatureBinningParam
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
from federatedml.util import LOGGER
from federatedml.util import consts
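
# Virtual summary binning: approximate global quantile split points for
# homogeneous (horizontal) federated learning without pooling raw data.
# Each client builds a local quantile summary per feature; the parties
# then securely aggregate instance counts, per-feature min/max, missing
# counts, and the global ranks of a shared set of candidate query points,
# and each client interpolates the final split points from the resulting
# (value, rank) pairs in `Client._query`.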


class Server(homo_binning_base.Server):
    """Server-side (arbiter) role of virtual summary binning."""

    def __init__(self, params=None, abnormal_list=None):
        super().__init__(params, abnormal_list)

    def fit_split_points(self, data=None):
        if self.aggregator is None:
            self.aggregator = SecureAggregatorServer(True, communicate_match_suffix='virtual_summary_binning')
        # Drive the secure-aggregation rounds: total instance counts,
        # per-feature min/max, missing counts, then query-point ranks.
        self.get_total_count()
        self.get_min_max()
        self.get_missing_count()
        self.query_values()
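
# The four calls above are the server halves of the secure-aggregation
# rounds; `Client.fit` and `Client.fit_split_points` below issue the
# matching client-side calls, so both roles presumably have to run the
# rounds in the same order.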


class Client(homo_binning_base.Client):
    """Client-side (party) role of virtual summary binning."""

    def __init__(self, params: HomoFeatureBinningParam = None, abnormal_list=None, allow_duplicate=False):
        super().__init__(params, abnormal_list)
        self.allow_duplicate = allow_duplicate
        self.query_points = None
        self.global_ranks = None
        self.total_count = 0
        self.missing_count = 0

    def fit(self, data_inst):
        if self.bin_inner_param is None:
            self._setup_bin_inner_param(data_inst, self.params)
        self.total_count = self.get_total_count(data_inst)
        LOGGER.debug(f"abnormal_list: {self.abnormal_list}")
        quantile_tool = QuantileBinningTool(param_obj=self.params,
                                            abnormal_list=self.abnormal_list,
                                            allow_duplicate=self.allow_duplicate)
        quantile_tool.set_bin_inner_param(self.bin_inner_param)
        # Build a local quantile summary per feature, then aggregate the
        # global statistics needed to place and rank the query points.
        summary_table = quantile_tool.fit_summary(data_inst)
        self.get_min_max(data_inst)
        self.missing_count = self.get_missing_count(summary_table)
        self.query_points = self.init_query_points(summary_table.partitions,
                                                   split_num=self.params.sample_bins)
        self.global_ranks = self.query_values(summary_table, self.query_points)

    def fit_split_points(self, data_instances):
        if self.aggregator is None:
            self.aggregator = SecureAggregatorClient(
                secure_aggregate=True,
                aggregate_type='sum',
                communicate_match_suffix='virtual_summary_binning')
        self.fit(data_instances)
        query_func = functools.partial(self._query, bin_num=self.bin_num,
                                       missing_count=self.missing_count,
                                       total_count=self.total_count)
        # Pair each feature's query points with their aggregated global
        # ranks, then interpolate the final split points per feature.
        split_point_table = self.query_points.join(self.global_ranks, lambda x, y: (x, y))
        split_point_table = split_point_table.map(query_func)
        split_points = dict(split_point_table.collect())
        for col_name, sps in split_points.items():
            self.bin_results.put_col_split_points(col_name, sps)
        return self.bin_results.all_split_points

    def _query(self, feature_name, values, bin_num, missing_count, total_count):
        # Target quantiles: 1/bin_num, 2/bin_num, ..., (bin_num - 1)/bin_num, 1.0
        percent_value = 1.0 / bin_num
        percentile_rate = [i * percent_value for i in range(1, bin_num)]
        percentile_rate.append(1.0)
        # Convert quantiles into target ranks among non-missing instances.
        this_count = total_count - missing_count[feature_name]
        query_ranks = [int(x * this_count) for x in percentile_rate]
        query_points, global_ranks = values[0], values[1]
        query_values = [x.value for x in query_points]
        query_res = []
        for rank in query_ranks:
            idx = bisect.bisect_left(global_ranks, rank)
            if idx >= len(global_ranks) - 1:
                # Target rank at or beyond the last query point: clamp.
                query_res.append(query_values[-1])
            elif np.fabs(query_values[idx + 1] - query_values[idx]) < consts.FLOAT_ZERO:
                query_res.append(query_values[idx])
            elif np.fabs(global_ranks[idx + 1] - global_ranks[idx]) < consts.FLOAT_ZERO:
                # Equal neighbouring ranks would divide by zero below.
                query_res.append(query_values[idx])
            else:
                # Linear interpolation between neighbouring query points.
                approx_value = query_values[idx] + (query_values[idx + 1] - query_values[idx]) * \
                    ((rank - global_ranks[idx]) /
                     (global_ranks[idx + 1] - global_ranks[idx]))
                query_res.append(approx_value)
        if not self.allow_duplicate:
            query_res = sorted(set(query_res))
        return feature_name, query_res
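

# The core of `_query` is plain linear interpolation over (value, rank)
# pairs. The standalone sketch below (not part of the FATE API; all names
# and numbers are illustrative) reproduces that logic, including the
# detail that `bisect_left` selects the upper bracketing point, so the
# interpolation runs downward from `global_ranks[idx]` when the target
# rank falls short of it.
def _interpolate_split_points(query_values, global_ranks, bin_num, total_count):
    """Approximate bin_num-quantile split points from sampled global ranks."""
    targets = [int(i * total_count / bin_num) for i in range(1, bin_num + 1)]
    result = []
    for rank in targets:
        idx = bisect.bisect_left(global_ranks, rank)
        if idx >= len(global_ranks) - 1:
            result.append(query_values[-1])   # clamp at the last query point
        elif global_ranks[idx + 1] == global_ranks[idx]:
            result.append(query_values[idx])  # avoid division by zero
        else:
            gap = (rank - global_ranks[idx]) / (global_ranks[idx + 1] - global_ranks[idx])
            result.append(query_values[idx] + gap * (query_values[idx + 1] - query_values[idx]))
    return result

# Example: 100 non-missing instances, five candidate values with their
# (hypothetical) aggregated global ranks, quartile split points requested:
#   _interpolate_split_points([0, 2.5, 5, 7.5, 10], [0, 30, 55, 80, 100], 4, 100)
#   -> [2.0, 4.5, 6.875, 10]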