#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import functools
import math
import operator

import numpy as np

from fate_arch.session import computing_session as session
from federatedml.feature.binning.base_binning import BaseBinning
from federatedml.feature.binning.bucket_binning import BucketBinning
from federatedml.feature.binning.optimal_binning import bucket_info
from federatedml.feature.binning.optimal_binning import heap
from federatedml.feature.binning.quantile_tool import QuantileBinningTool
from federatedml.param.feature_binning_param import HeteroFeatureBinningParam, OptimalBinningParam
from federatedml.statistic import data_overview
from federatedml.statistic import statics
from federatedml.util import LOGGER
from federatedml.util import consts


class OptimalBinning(BaseBinning):
    def __init__(self, params, abnormal_list=None):
        super().__init__(params, abnormal_list)
        # The following attributes are used only during the fitting process
        if isinstance(params, HeteroFeatureBinningParam):
            self.optimal_param = params.optimal_binning_param
            self.optimal_param.adjustment_factor = params.adjustment_factor
            self.optimal_param.max_bin = params.bin_num
            if math.ceil(1.0 / self.optimal_param.max_bin_pct) > self.optimal_param.max_bin:
                raise ValueError("Arguments logical error: ceil(1.0 / max_bin_pct) "
                                 "should be less than or equal to bin_num")
        self.adjustment_factor = params.adjustment_factor
        self.event_total = None
        self.non_event_total = None
        self.bucket_lists = {}

    def fit_split_points(self, data_instances):
        header = data_overview.get_header(data_instances)
        anonymous_header = data_overview.get_anonymous_header(data_instances)
        self._default_setting(header, anonymous_header)
        if self.event_total is None or self.non_event_total is None:
            self.event_total, self.non_event_total = self.get_histogram(data_instances)
        # LOGGER.debug("In fit split points, event_total: {}, non_event_total: {}".format(
        #     self.event_total, self.non_event_total))

        bucket_table = self.init_bucket(data_instances)
        sample_count = data_instances.count()
        self.fit_buckets(bucket_table, sample_count)
        self.fit_category_features(data_instances)
        return self.bin_results.all_split_points
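
    # fit_buckets dispatches on optimal_param.metric_method: for 'iv', 'gini' and
    # 'chi_square' it merges buckets bottom-up (merge_optimal_binning); otherwise
    # it splits top-down on the KS statistic (split_optimal_binning).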
    def fit_buckets(self, bucket_table, sample_count):
        if self.optimal_param.metric_method in ['iv', 'gini', 'chi_square']:
            optimal_binning_method = functools.partial(self.merge_optimal_binning,
                                                       optimal_param=self.optimal_param,
                                                       sample_count=sample_count)
            result_bucket = bucket_table.mapValues(optimal_binning_method)
            for col_name, (min_heap, bucket_list, non_mixture_num, small_size_num) in result_bucket.collect():
                split_points = np.unique([bucket.right_bound for bucket in bucket_list]).tolist()
                self.bin_results.put_col_split_points(col_name, split_points)
                metric_array = [node.score for node in min_heap.node_list]
                self.bin_results.put_col_optimal_metric_array(col_name, metric_array)
                # LOGGER.debug(f"column {col_name}, split_points: {split_points}, metric array: {metric_array}")
                self.bucket_lists[col_name] = bucket_list
        else:
            optimal_binning_method = functools.partial(self.split_optimal_binning,
                                                       optimal_param=self.optimal_param,
                                                       sample_count=sample_count)
            result_bucket = bucket_table.mapValues(optimal_binning_method)
            for col_name, (bucket_list, non_mixture_num, small_size_num, res_ks_array) in result_bucket.collect():
                split_points = np.unique([bucket.right_bound for bucket in bucket_list]).tolist()
                self.bin_results.put_col_split_points(col_name, split_points)
                self.bin_results.put_col_optimal_metric_array(col_name, res_ks_array)
                # LOGGER.debug(f"column {col_name}, split_points: {split_points}, metric array: {res_ks_array}")
                self.bucket_lists[col_name] = bucket_list
        return result_bucket

    def init_bucket(self, data_instances):
        header = data_overview.get_header(data_instances)
        anonymous_header = data_overview.get_anonymous_header(data_instances)
        self._default_setting(header, anonymous_header)

        init_bucket_param = copy.deepcopy(self.params)
        init_bucket_param.bin_num = self.optimal_param.init_bin_nums
        if self.optimal_param.init_bucket_method == consts.QUANTILE:
            init_binning_obj = QuantileBinningTool(param_obj=init_bucket_param, allow_duplicate=False)
        else:
            init_binning_obj = BucketBinning(params=init_bucket_param)
        init_binning_obj.set_bin_inner_param(self.bin_inner_param)
        init_split_points = init_binning_obj.fit_split_points(data_instances)
        is_sparse = data_overview.is_sparse_data(data_instances)

        bucket_dict = dict()
        for col_name, sps in init_split_points.items():
            bucket_list = []
            for idx, sp in enumerate(sps):
                bucket = bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
                if idx == 0:
                    bucket.left_bound = -math.inf
                    bucket.set_left_neighbor(None)
                else:
                    bucket.left_bound = sps[idx - 1]
                bucket.event_total = self.event_total
                bucket.non_event_total = self.non_event_total
                bucket_list.append(bucket)
            bucket_list[-1].set_right_neighbor(None)
            bucket_dict[col_name] = bucket_list
            # LOGGER.debug(f"col_name: {col_name}, length of sps: {len(sps)}, "
            #              f"length of list: {len(bucket_list)}")

        convert_func = functools.partial(self.convert_data_to_bucket,
                                         split_points=init_split_points,
                                         headers=self.header,
                                         bucket_dict=copy.deepcopy(bucket_dict),
                                         is_sparse=is_sparse,
                                         get_bin_num_func=self.get_bin_num)
        bucket_table = data_instances.mapReducePartitions(convert_func, self.merge_bucket_list)
        return bucket_table
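
    # The label histogram is computed once over the whole table; the event /
    # non-event totals are then copied onto every bucket so that per-bucket
    # rates (and thus the binning metrics) can be evaluated locally.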
    @staticmethod
    def get_histogram(data_instances):
        static_obj = statics.MultivariateStatisticalSummary(data_instances, cols_index=-1)
        label_histogram = static_obj.get_label_histogram()
        event_total = label_histogram.get(1, 0)
        non_event_total = label_histogram.get(0, 0)
        # if event_total == 0 or non_event_total == 0:
        #     LOGGER.warning(f"event_total or non_event_total might have errors, event_total: {event_total},"
        #                    f" non_event_total: {non_event_total}")
        return event_total, non_event_total

    @staticmethod
    def assign_histogram(bucket_list, event_total, non_event_total):
        for bucket in bucket_list:
            bucket.event_total = event_total
            bucket.non_event_total = non_event_total
        return bucket_list

    @staticmethod
    def merge_bucket_list(list1, list2):
        if len(list1) != len(list2):
            raise AssertionError("In merge_bucket_list, the lengths of the two lists are not equal")
        result = []
        for idx, b1 in enumerate(list1):
            b2 = list2[idx]
            result.append(b1.merge(b2))
        return result

    @staticmethod
    def convert_data_to_bucket(data_iter, split_points, headers, bucket_dict,
                               is_sparse, get_bin_num_func):
        for data_key, instance in data_iter:
            label = instance.label
            if not is_sparse:
                if type(instance).__name__ == 'Instance':
                    features = instance.features
                else:
                    features = instance
                data_generator = enumerate(features)
            else:
                data_generator = instance.features.get_all_data()

            for idx, col_value in data_generator:
                col_name = headers[idx]
                if col_name not in split_points:
                    continue
                col_split_points = split_points[col_name]
                bin_num = get_bin_num_func(col_value, col_split_points)
                bucket = bucket_dict[col_name][bin_num]
                bucket.add(label, col_value)

        result = []
        for col_name, bucket_list in bucket_dict.items():
            result.append((col_name, bucket_list))
        return result
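
    # merge_optimal_binning greedily merges neighboring buckets: every feasible
    # neighbor pair becomes a node in a min-heap keyed by the merge's metric
    # score, and the cheapest merges are applied first. Constraint phases run in
    # order: 'mixture' / 'single_mixture' (each bucket should contain both
    # events and non-events), 'small_size' / 'single_small_size' (each bucket
    # should hold at least min_bin_pct * sample_count samples), then
    # unconstrained merging until at most max_bin buckets remain.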
    @staticmethod
    def merge_optimal_binning(bucket_list, optimal_param: OptimalBinningParam, sample_count):
        max_item_num = math.floor(optimal_param.max_bin_pct * sample_count)
        min_item_num = math.ceil(optimal_param.min_bin_pct * sample_count)
        bucket_dict = {idx: bucket for idx, bucket in enumerate(bucket_list)}
        final_max_bin = optimal_param.max_bin
        # LOGGER.debug("Get in merge optimal binning, sample_count: {}, max_item_num: {}, min_item_num: {},"
        #              " final_max_bin: {}".format(sample_count, max_item_num, min_item_num, final_max_bin))
        min_heap = heap.MinHeap()

        def _add_heap_nodes(constraint=None):
            # LOGGER.debug("Add heap nodes, constraint: {}, dict_length: {}".format(constraint, len(bucket_dict)))
            this_non_mixture_num = 0
            this_small_size_num = 0
            # Count constraint violations and collect feasible neighbor-pair merges
            for i in range(len(bucket_dict)):
                left_bucket = bucket_dict[i]
                right_bucket = bucket_dict.get(left_bucket.right_neighbor_idx)
                if left_bucket.right_neighbor_idx == i:
                    raise RuntimeError("left_bucket's right neighbor == itself")

                if not left_bucket.is_mixed:
                    this_non_mixture_num += 1
                if left_bucket.total_count < min_item_num:
                    this_small_size_num += 1

                if right_bucket is None:
                    continue
                # Violates the maximum-items constraint
                if left_bucket.total_count + right_bucket.total_count > max_item_num:
                    continue

                if constraint == 'mixture':
                    if left_bucket.is_mixed or right_bucket.is_mixed:
                        continue
                elif constraint == 'single_mixture':
                    if left_bucket.is_mixed and right_bucket.is_mixed:
                        continue
                elif constraint == 'small_size':
                    if left_bucket.total_count >= min_item_num or right_bucket.total_count >= min_item_num:
                        continue
                elif constraint == 'single_small_size':
                    if left_bucket.total_count >= min_item_num and right_bucket.total_count >= min_item_num:
                        continue
                heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket, right_bucket=right_bucket)
                min_heap.insert(heap_node)
            return min_heap, this_non_mixture_num, this_small_size_num

        def _update_bucket_info(b_dict):
            """Re-index buckets contiguously from 0, ordered by left bound, and relink neighbors"""
            order_dict = dict()
            for bucket_idx, item in b_dict.items():
                order_dict[bucket_idx] = item.left_bound
            sorted_order_dict = sorted(order_dict.items(), key=operator.itemgetter(1))

            start_idx = 0
            for item in sorted_order_dict:
                bucket_idx = item[0]
                if start_idx == bucket_idx:
                    start_idx += 1
                    continue
                b_dict[start_idx] = b_dict[bucket_idx]
                b_dict[start_idx].idx = start_idx
                start_idx += 1
                del b_dict[bucket_idx]

            bucket_num = len(b_dict)
            for i in range(bucket_num):
                if i == 0:
                    b_dict[i].set_left_neighbor(None)
                    b_dict[i].set_right_neighbor(i + 1)
                else:
                    b_dict[i].set_left_neighbor(i - 1)
                    b_dict[i].set_right_neighbor(i + 1)
            b_dict[bucket_num - 1].set_right_neighbor(None)
            return b_dict

        def _merge_heap(constraint=None, aim_var=0):
            next_id = max(bucket_dict.keys()) + 1
            while aim_var > 0 and not min_heap.is_empty:
                min_node = min_heap.pop()
                left_bucket = min_node.left_bucket
                right_bucket = min_node.right_bucket

                # Some buckets may have been merged already
                if left_bucket.idx not in bucket_dict or right_bucket.idx not in bucket_dict:
                    continue
                new_bucket = bucket_info.Bucket(idx=next_id, adjustment_factor=optimal_param.adjustment_factor)
                new_bucket = _init_new_bucket(new_bucket, min_node)
                bucket_dict[next_id] = new_bucket
                del bucket_dict[left_bucket.idx]
                del bucket_dict[right_bucket.idx]
                min_heap.remove_empty_node(left_bucket.idx)
                min_heap.remove_empty_node(right_bucket.idx)
                aim_var = _aim_vars_decrease(constraint, new_bucket, left_bucket, right_bucket, aim_var)
                _add_node_from_new_bucket(new_bucket, constraint)
                next_id += 1
            return min_heap, aim_var

        def _add_node_from_new_bucket(new_bucket: bucket_info.Bucket, constraint):
            left_bucket = bucket_dict.get(new_bucket.left_neighbor_idx)
            right_bucket = bucket_dict.get(new_bucket.right_neighbor_idx)
            if constraint == 'mixture':
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if not left_bucket.is_mixed and not new_bucket.is_mixed:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                           right_bucket=new_bucket)
                        min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if not right_bucket.is_mixed and not new_bucket.is_mixed:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                           right_bucket=right_bucket)
                        min_heap.insert(heap_node)
            elif constraint == 'single_mixture':
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if not (left_bucket.is_mixed and new_bucket.is_mixed):
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                           right_bucket=new_bucket)
                        min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if not (right_bucket.is_mixed and new_bucket.is_mixed):
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                           right_bucket=right_bucket)
                        min_heap.insert(heap_node)
            elif constraint == 'small_size':
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if left_bucket.total_count < min_item_num and new_bucket.total_count < min_item_num:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                           right_bucket=new_bucket)
                        min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if right_bucket.total_count < min_item_num and new_bucket.total_count < min_item_num:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                           right_bucket=right_bucket)
                        min_heap.insert(heap_node)
            elif constraint == 'single_small_size':
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if left_bucket.total_count < min_item_num or new_bucket.total_count < min_item_num:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                           right_bucket=new_bucket)
                        min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    if right_bucket.total_count < min_item_num or new_bucket.total_count < min_item_num:
                        heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                           right_bucket=right_bucket)
                        min_heap.insert(heap_node)
            else:
                if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
                                                       right_bucket=new_bucket)
                    min_heap.insert(heap_node)
                if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
                    heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
                                                       right_bucket=right_bucket)
                    min_heap.insert(heap_node)

        def _init_new_bucket(new_bucket: bucket_info.Bucket, min_node: heap.HeapNode):
            new_bucket.left_bound = min_node.left_bucket.left_bound
            new_bucket.right_bound = min_node.right_bucket.right_bound
            new_bucket.left_neighbor_idx = min_node.left_bucket.left_neighbor_idx
            new_bucket.right_neighbor_idx = min_node.right_bucket.right_neighbor_idx
            new_bucket.event_count = min_node.left_bucket.event_count + min_node.right_bucket.event_count
            new_bucket.non_event_count = min_node.left_bucket.non_event_count + min_node.right_bucket.non_event_count
            new_bucket.event_total = min_node.left_bucket.event_total
            new_bucket.non_event_total = min_node.left_bucket.non_event_total

            left_neighbor_bucket = bucket_dict.get(new_bucket.left_neighbor_idx)
            if left_neighbor_bucket is not None:
                left_neighbor_bucket.right_neighbor_idx = new_bucket.idx
            right_neighbor_bucket = bucket_dict.get(new_bucket.right_neighbor_idx)
            if right_neighbor_bucket is not None:
                right_neighbor_bucket.left_neighbor_idx = new_bucket.idx
            return new_bucket

        def _aim_vars_decrease(constraint, new_bucket: bucket_info.Bucket, left_bucket, right_bucket, aim_var):
            if constraint in ['mixture', 'single_mixture']:
                if not left_bucket.is_mixed:
                    aim_var -= 1
                if not right_bucket.is_mixed:
                    aim_var -= 1
                if not new_bucket.is_mixed:
                    aim_var += 1
            elif constraint in ['small_size', 'single_small_size']:
                if left_bucket.total_count < min_item_num:
                    aim_var -= 1
                if right_bucket.total_count < min_item_num:
                    aim_var -= 1
                if new_bucket.total_count < min_item_num:
                    aim_var += 1
            else:
                aim_var = len(bucket_dict) - final_max_bin
            return aim_var

        if optimal_param.mixture:
            LOGGER.debug(f"Before mixture add, dict length: {len(bucket_dict)}")
            min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='mixture')
            min_heap, non_mixture_num = _merge_heap(constraint='mixture', aim_var=non_mixture_num)
            bucket_dict = _update_bucket_info(bucket_dict)

            min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='single_mixture')
            min_heap, non_mixture_num = _merge_heap(constraint='single_mixture', aim_var=non_mixture_num)
            LOGGER.debug(f"After mixture merge, min_heap size: {min_heap.size}, non_mixture_num: {non_mixture_num}")
            bucket_dict = _update_bucket_info(bucket_dict)

        LOGGER.debug(f"Before small_size add, dict length: {len(bucket_dict)}")
        min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='small_size')
        min_heap, small_size_num = _merge_heap(constraint='small_size', aim_var=small_size_num)
        bucket_dict = _update_bucket_info(bucket_dict)

        min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='single_small_size')
        min_heap, small_size_num = _merge_heap(constraint='single_small_size', aim_var=small_size_num)
        bucket_dict = _update_bucket_info(bucket_dict)

        # LOGGER.debug(f"Before add, dict length: {len(bucket_dict)}")
        min_heap, non_mixture_num, small_size_num = _add_heap_nodes()
        # LOGGER.debug("After normal add, small_size: {}, min_heap size: {}".format(small_size_num, min_heap.size))
        min_heap, total_bucket_num = _merge_heap(aim_var=len(bucket_dict) - final_max_bin)
        # LOGGER.debug("After normal merge, min_heap size: {}".format(min_heap.size))

        non_mixture_num = 0
        small_size_num = 0
        for i, bucket in bucket_dict.items():
            if not bucket.is_mixed:
                non_mixture_num += 1
            if bucket.total_count < min_item_num:
                small_size_num += 1
        bucket_res = list(bucket_dict.values())
        bucket_res = sorted(bucket_res, key=lambda bucket: bucket.left_bound)
        # LOGGER.debug("Before return, dict length: {}".format(len(bucket_dict)))
        # LOGGER.debug(f"Before return, min heap node list length: {len(min_heap.node_list)}")
        return min_heap, bucket_res, non_mixture_num, small_size_num
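
    # split_optimal_binning works top-down instead: starting from the full range
    # of initial buckets, it repeatedly splits at the position with the largest
    # KS statistic (the maximum gap between cumulative event rate and cumulative
    # non-event rate), subject to the mixture and minimum-size constraints,
    # until final_max_bin bins are reached or no valid split remains.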
    @staticmethod
    def split_optimal_binning(bucket_list, optimal_param: OptimalBinningParam, sample_count):
        min_item_num = math.ceil(optimal_param.min_bin_pct * sample_count)
        final_max_bin = optimal_param.max_bin

        def _compute_ks(start_idx, end_idx):
            acc_event = []
            acc_non_event = []
            curt_event_total = 0
            curt_non_event_total = 0
            for bucket in bucket_list[start_idx: end_idx]:
                acc_event.append(bucket.event_count + curt_event_total)
                curt_event_total += bucket.event_count
                acc_non_event.append(bucket.non_event_count + curt_non_event_total)
                curt_non_event_total += bucket.non_event_count

            if curt_event_total == 0 or curt_non_event_total == 0:
                return None, None, None
            acc_event_rate = [x / curt_event_total for x in acc_event]
            acc_non_event_rate = [x / curt_non_event_total for x in acc_non_event]
            ks_list = [math.fabs(eve - non_eve) for eve, non_eve in zip(acc_event_rate, acc_non_event_rate)]
            if max(ks_list) == 0:
                best_index = len(ks_list) // 2
            else:
                best_index = ks_list.index(max(ks_list))

            left_event = acc_event[best_index]
            right_event = curt_event_total - left_event
            left_non_event = acc_non_event[best_index]
            right_non_event = curt_non_event_total - left_non_event
            left_total = left_event + left_non_event
            right_total = right_event + right_non_event

            # If either side would be too small, fall back to the middle split point
            if left_total < min_item_num or right_total < min_item_num:
                best_index = len(ks_list) // 2
                left_event = acc_event[best_index]
                right_event = curt_event_total - left_event
                left_non_event = acc_non_event[best_index]
                right_non_event = curt_non_event_total - left_non_event
                left_total = left_event + left_non_event
                right_total = right_event + right_non_event

            best_ks = ks_list[best_index]
            res_dict = {
                'left_event': left_event,
                'right_event': right_event,
                'left_non_event': left_non_event,
                'right_non_event': right_non_event,
                'left_total': left_total,
                'right_total': right_total,
                'left_is_mixed': left_event > 0 and left_non_event > 0,
                'right_is_mixed': right_event > 0 and right_non_event > 0
            }
            return best_ks, start_idx + best_index, res_dict

        def _merge_buckets(start_idx, end_idx, bucket_idx):
            res_bucket = copy.deepcopy(bucket_list[start_idx])
            res_bucket.idx = bucket_idx
            for bucket in bucket_list[start_idx + 1: end_idx]:
                res_bucket = res_bucket.merge(bucket)
            return res_bucket

        res_split_index = []
        res_split_ks = {}
        to_split_pair = [(0, len(bucket_list))]

        # Iteratively split; the larger half is queued first
        while len(to_split_pair) > 0:
            if len(res_split_index) >= final_max_bin - 1:
                break
            start, end = to_split_pair.pop(0)
            if start >= end:
                continue
            best_ks, best_index, res_dict = _compute_ks(start, end)
            if best_ks is None:
                continue
            if optimal_param.mixture:
                if not (res_dict.get('left_is_mixed') and res_dict.get('right_is_mixed')):
                    continue
            if res_dict.get('left_total') < min_item_num or res_dict.get('right_total') < min_item_num:
                continue
            res_split_index.append(best_index + 1)
            res_split_ks[best_index + 1] = best_ks

            if res_dict.get('right_total') > res_dict.get('left_total'):
                to_split_pair.append((best_index + 1, end))
                to_split_pair.append((start, best_index + 1))
            else:
                to_split_pair.append((start, best_index + 1))
                to_split_pair.append((best_index + 1, end))
            # LOGGER.debug("to_split_pair: {}".format(to_split_pair))

        if len(res_split_index) == 0:
            LOGGER.warning("Best-KS optimal binning failed to split; taking the middle split point instead")
            res_split_index.append(len(bucket_list) // 2)
        res_split_index = sorted(res_split_index)
        res_ks = []
        if res_split_ks:
            res_ks = [res_split_ks[idx] for idx in res_split_index]
        # last bin
        # res_ks.append(0.0)
        res_split_index.append(len(bucket_list))

        start = 0
        bucket_res = []
        non_mixture_num = 0
        small_size_num = 0
        for bucket_idx, end in enumerate(res_split_index):
            new_bucket = _merge_buckets(start, end, bucket_idx)
            bucket_res.append(new_bucket)
            if not new_bucket.is_mixed:
                non_mixture_num += 1
            if new_bucket.total_count < min_item_num:
                small_size_num += 1
            start = end
        return bucket_res, non_mixture_num, small_size_num, res_ks

    def bin_sum_to_bucket_list(self, bin_sum, partitions):
        """
        Convert a bin-sum result, typically obtained from a host, into bucket lists

        Parameters
        ----------
        bin_sum : dict
            {'x1': [[event_count, non_event_count], [event_count, non_event_count] ... ],
             'x2': [[event_count, non_event_count], [event_count, non_event_count] ... ],
             ...
            }

        partitions : int
            Number of partitions for the created table.

        Returns
        -------
        A Table whose keys are feature names and values are bucket lists
        """
        bucket_dict = dict()
        for col_name, bin_res_list in bin_sum.items():
            bucket_list = []
            for b_idx in range(len(bin_res_list)):
                bucket = bucket_info.Bucket(b_idx, self.adjustment_factor)
                if b_idx == 0:
                    bucket.set_left_neighbor(None)
                if b_idx == len(bin_res_list) - 1:
                    bucket.set_right_neighbor(None)
                bucket.event_count = bin_res_list[b_idx][0]
                bucket.non_event_count = bin_res_list[b_idx][1]
                bucket.left_bound = b_idx - 1
                bucket.right_bound = b_idx
                bucket.event_total = self.event_total
                bucket.non_event_total = self.non_event_total
                bucket_list.append(bucket)
            bucket_dict[col_name] = bucket_list

        result = []
        for col_name, bucket_list in bucket_dict.items():
            result.append((col_name, bucket_list))
        result_table = session.parallelize(result,
                                           include_key=True,
                                           partition=partitions)
        return result_table
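

# Illustrative usage (a minimal sketch; assumes a FATE computing session has
# already been initialized and that `data_instances` is a Table of labeled
# Instance objects -- parameter values below are examples only):
#
#     param = HeteroFeatureBinningParam(method=consts.OPTIMAL, bin_num=10)
#     binning = OptimalBinning(params=param)
#     split_points = binning.fit_split_points(data_instances)
#     # split_points: {col_name: [split_point, ...], ...}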