quantile_binning.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. import functools
  19. import uuid
  20. from fate_arch.common.versions import get_eggroll_version
  21. from federatedml.feature.binning.base_binning import BaseBinning
  22. from federatedml.feature.binning.quantile_summaries import quantile_summary_factory
  23. from federatedml.param.feature_binning_param import FeatureBinningParam
  24. from federatedml.statistic import data_overview
  25. from federatedml.util import LOGGER
  26. from federatedml.util import consts
  27. import numpy as np
  28. class QuantileBinning(BaseBinning):
  29. """
  30. After quantile binning, the numbers of elements in each binning are equal.
  31. The result of this algorithm has the following deterministic bound:
  32. If the data_instances has N elements and if we request the quantile at probability `p` up to error
  33. `err`, then the algorithm will return a sample `x` from the data so that the *exact* rank
  34. of `x` is close to (p * N).
  35. More precisely,
  36. {{{
  37. floor((p - 2 * err) * N) <= rank(x) <= ceil((p + 2 * err) * N)
  38. }}}
  39. This method implements a variation of the Greenwald-Khanna algorithm (with some speed
  40. optimizations).
  41. """
  42. def __init__(self, params: FeatureBinningParam, abnormal_list=None, allow_duplicate=False):
  43. super(QuantileBinning, self).__init__(params, abnormal_list)
  44. self.summary_dict = None
  45. self.allow_duplicate = allow_duplicate
  46. def fit_split_points(self, data_instances):
  47. """
  48. Apply the binning method
  49. Parameters
  50. ----------
  51. data_instances : Table
  52. The input data
  53. Returns
  54. -------
  55. split_points : dict.
  56. Each value represent for the split points for a feature. The element in each row represent for
  57. the corresponding split point.
  58. e.g.
  59. split_points = {'x1': [0.1, 0.2, 0.3, 0.4 ...], # The first feature
  60. 'x2': [1, 2, 3, 4, ...], # The second feature
  61. ... # Other features
  62. }
  63. """
  64. header = data_overview.get_header(data_instances)
  65. anonymous_header = data_overview.get_anonymous_header(data_instances)
  66. LOGGER.debug("Header length: {}".format(len(header)))
  67. self._default_setting(header, anonymous_header)
  68. # self._init_cols(data_instances)
  69. percent_value = 1.0 / self.bin_num
  70. # calculate the split points
  71. percentile_rate = [i * percent_value for i in range(1, self.bin_num)]
  72. percentile_rate.append(1.0)
  73. is_sparse = data_overview.is_sparse_data(data_instances)
  74. self._fit_split_point(data_instances, is_sparse, percentile_rate)
  75. self.fit_category_features(data_instances)
  76. return self.bin_results.all_split_points
  77. @staticmethod
  78. def copy_merge(s1, s2):
  79. # new_s1 = copy.deepcopy(s1)
  80. return s1.merge(s2)
  81. def _fit_split_point(self, data_instances, is_sparse, percentile_rate):
  82. if self.summary_dict is None:
  83. f = functools.partial(self.feature_summary,
  84. params=self.params,
  85. abnormal_list=self.abnormal_list,
  86. cols_dict=self.bin_inner_param.bin_cols_map,
  87. header=self.header,
  88. is_sparse=is_sparse)
  89. # summary_dict_table = data_instances.mapReducePartitions(f, self.copy_merge)
  90. summary_dict_table = data_instances.mapReducePartitions(f, lambda s1, s2: s1.merge(s2))
  91. # summary_dict = dict(summary_dict.collect())
  92. if is_sparse:
  93. total_count = data_instances.count()
  94. summary_dict_table = summary_dict_table.mapValues(lambda x: x.set_total_count(total_count))
  95. self.summary_dict = summary_dict_table
  96. else:
  97. summary_dict_table = self.summary_dict
  98. f = functools.partial(self._get_split_points,
  99. allow_duplicate=self.allow_duplicate,
  100. percentile_rate=percentile_rate)
  101. summary_dict = dict(summary_dict_table.mapValues(f).collect())
  102. for col_name, split_point in summary_dict.items():
  103. self.bin_results.put_col_split_points(col_name, split_point)
  104. @staticmethod
  105. def _get_split_points(summary, percentile_rate, allow_duplicate):
  106. split_points = summary.query_percentile_rate_list(percentile_rate)
  107. if not allow_duplicate:
  108. return np.unique(split_points)
  109. else:
  110. return np.array(split_points)
  111. @staticmethod
  112. def feature_summary(data_iter, params, cols_dict, abnormal_list, header, is_sparse):
  113. summary_dict = {}
  114. summary_param = {'compress_thres': params.compress_thres,
  115. 'head_size': params.head_size,
  116. 'error': params.error,
  117. 'abnormal_list': abnormal_list}
  118. for col_name, col_index in cols_dict.items():
  119. quantile_summaries = quantile_summary_factory(is_sparse=is_sparse, param_dict=summary_param)
  120. summary_dict[col_name] = quantile_summaries
  121. _ = str(uuid.uuid1())
  122. for _, instant in data_iter:
  123. if not is_sparse:
  124. if type(instant).__name__ == 'Instance':
  125. features = instant.features
  126. else:
  127. features = instant
  128. for col_name, summary in summary_dict.items():
  129. col_index = cols_dict[col_name]
  130. summary.insert(features[col_index])
  131. else:
  132. data_generator = instant.features.get_all_data()
  133. for col_idx, col_value in data_generator:
  134. col_name = header[col_idx]
  135. if col_name not in cols_dict:
  136. continue
  137. summary = summary_dict[col_name]
  138. summary.insert(col_value)
  139. result = []
  140. for features_name, summary_obj in summary_dict.items():
  141. summary_obj.compress()
  142. # result.append(((_, features_name), summary_obj))
  143. result.append((features_name, summary_obj))
  144. return result
  145. @staticmethod
  146. def _query_split_points(summary, percent_rates):
  147. split_point = []
  148. for percent_rate in percent_rates:
  149. s_p = summary.query(percent_rate)
  150. if s_p not in split_point:
  151. split_point.append(s_p)
  152. return split_point
  153. @staticmethod
  154. def approxi_quantile(data_instances, params, cols_dict, abnormal_list, header, is_sparse):
  155. """
  156. Calculates each quantile information
  157. Parameters
  158. ----------
  159. data_instances : Table
  160. The input data
  161. cols_dict: dict
  162. Record key, value pairs where key is cols' name, and value is cols' index.
  163. params : FeatureBinningParam object,
  164. Parameters that user set.
  165. abnormal_list: list, default: None
  166. Specify which columns are abnormal so that will not static when traveling.
  167. header: list,
  168. Storing the header information.
  169. is_sparse: bool
  170. Specify whether data_instance is in sparse type
  171. Returns
  172. -------
  173. summary_dict: dict
  174. {'col_name1': summary1,
  175. 'col_name2': summary2,
  176. ...
  177. }
  178. """
  179. summary_dict = {}
  180. summary_param = {'compress_thres': params.compress_thres,
  181. 'head_size': params.head_size,
  182. 'error': params.error,
  183. 'abnormal_list': abnormal_list}
  184. for col_name, col_index in cols_dict.items():
  185. quantile_summaries = quantile_summary_factory(is_sparse=is_sparse, param_dict=summary_param)
  186. summary_dict[col_name] = quantile_summaries
  187. QuantileBinning.insert_datas(data_instances, summary_dict, cols_dict, header, is_sparse)
  188. for _, summary_obj in summary_dict.items():
  189. summary_obj.compress()
  190. return summary_dict
  191. @staticmethod
  192. def insert_datas(data_instances, summary_dict, cols_dict, header, is_sparse):
  193. for iter_key, instant in data_instances:
  194. if not is_sparse:
  195. if type(instant).__name__ == 'Instance':
  196. features = instant.features
  197. else:
  198. features = instant
  199. for col_name, summary in summary_dict.items():
  200. col_index = cols_dict[col_name]
  201. summary.insert(features[col_index])
  202. else:
  203. data_generator = instant.features.get_all_data()
  204. for col_idx, col_value in data_generator:
  205. col_name = header[col_idx]
  206. summary = summary_dict[col_name]
  207. summary.insert(col_value)
  208. @staticmethod
  209. def merge_summary_dict(s_dict1, s_dict2):
  210. if s_dict1 is None and s_dict2 is None:
  211. return None
  212. if s_dict1 is None:
  213. return s_dict2
  214. if s_dict2 is None:
  215. return s_dict1
  216. s_dict1 = copy.deepcopy(s_dict1)
  217. s_dict2 = copy.deepcopy(s_dict2)
  218. new_dict = {}
  219. for col_name, summary1 in s_dict1.items():
  220. summary2 = s_dict2.get(col_name)
  221. summary1.merge(summary2)
  222. new_dict[col_name] = summary1
  223. return new_dict
  224. @staticmethod
  225. def _query_quantile_points(col_name, summary, quantile_dict):
  226. quantile = quantile_dict.get(col_name)
  227. if quantile is not None:
  228. return col_name, summary.query(quantile)
  229. return col_name, quantile
  230. def query_quantile_point(self, query_points, col_names=None):
  231. if self.summary_dict is None:
  232. raise RuntimeError("Bin object should be fit before query quantile points")
  233. if col_names is None:
  234. col_names = self.bin_inner_param.bin_names
  235. summary_dict = self.summary_dict
  236. if isinstance(query_points, (int, float)):
  237. query_dict = {}
  238. for col_name in col_names:
  239. query_dict[col_name] = query_points
  240. elif isinstance(query_points, dict):
  241. query_dict = query_points
  242. else:
  243. raise ValueError("query_points has wrong type, should be a float, int or dict")
  244. f = functools.partial(self._query_quantile_points,
  245. quantile_dict=query_dict)
  246. result = dict(summary_dict.map(f).collect())
  247. return result
  248. class QuantileBinningTool(QuantileBinning):
  249. """
  250. Use for quantile binning data directly.
  251. """
  252. def __init__(self, bin_nums=consts.G_BIN_NUM, param_obj: FeatureBinningParam = None,
  253. abnormal_list=None, allow_duplicate=False):
  254. if param_obj is None:
  255. param_obj = FeatureBinningParam(bin_num=bin_nums)
  256. super().__init__(params=param_obj, abnormal_list=abnormal_list, allow_duplicate=allow_duplicate)