- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- import functools
- import math
- import operator
- import numpy as np
- from fate_arch.session import computing_session as session
- from federatedml.feature.binning.base_binning import BaseBinning
- from federatedml.feature.binning.bucket_binning import BucketBinning
- from federatedml.feature.binning.optimal_binning import bucket_info
- from federatedml.feature.binning.optimal_binning import heap
- from federatedml.feature.binning.quantile_tool import QuantileBinningTool
- from federatedml.param.feature_binning_param import HeteroFeatureBinningParam, OptimalBinningParam
- from federatedml.statistic import data_overview
- from federatedml.statistic import statics
- from federatedml.util import LOGGER
- from federatedml.util import consts
- class OptimalBinning(BaseBinning):
- def __init__(self, params, abnormal_list=None):
- super().__init__(params, abnormal_list)
- """The following lines work only in fitting process"""
- if isinstance(params, HeteroFeatureBinningParam):
- self.optimal_param = params.optimal_binning_param
- self.optimal_param.adjustment_factor = params.adjustment_factor
- self.optimal_param.max_bin = params.bin_num
- if math.ceil(1.0 / self.optimal_param.max_bin_pct) > self.optimal_param.max_bin:
- raise ValueError("Arguments logical error, ceil(1.0/max_bin_pct) "
- "should be smaller or equal than bin_num")
- self.adjustment_factor = params.adjustment_factor
- self.event_total = None
- self.non_event_total = None
- self.bucket_lists = {}
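- # fit_split_points drives the whole optimal-binning procedure: it computes the global
- # event / non-event totals (if not already set), builds an initial fine-grained bucket
- # table, merges or splits the buckets per column according to the configured metric,
- # handles category features, and returns the resulting split points.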
- def fit_split_points(self, data_instances):
- header = data_overview.get_header(data_instances)
- anonymous_header = data_overview.get_anonymous_header(data_instances)
- self._default_setting(header, anonymous_header)
- if self.event_total is None or self.non_event_total is None:
- self.event_total, self.non_event_total = self.get_histogram(data_instances)
- # LOGGER.debug("In fit split points, event_total: {}, non_event_total: {}".format(self.event_total,
- # self.non_event_total))
- bucket_table = self.init_bucket(data_instances)
- sample_count = data_instances.count()
- self.fit_buckets(bucket_table, sample_count)
- self.fit_category_features(data_instances)
- return self.bin_results.all_split_points
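- # fit_buckets dispatches on the metric: iv / gini / chi_square use the bottom-up merge
- # strategy (merge_optimal_binning), while the remaining metrics (e.g. ks) use the
- # top-down split strategy (split_optimal_binning). Both run per column via mapValues.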
- def fit_buckets(self, bucket_table, sample_count):
- if self.optimal_param.metric_method in ['iv', 'gini', 'chi_square']:
- optimal_binning_method = functools.partial(self.merge_optimal_binning,
- optimal_param=self.optimal_param,
- sample_count=sample_count)
- result_bucket = bucket_table.mapValues(optimal_binning_method)
- for col_name, (min_heap, bucket_list, non_mixture_num, small_size_num) in result_bucket.collect():
- split_points = np.unique([bucket.right_bound for bucket in bucket_list]).tolist()
- self.bin_results.put_col_split_points(col_name, split_points)
- metric_array = [node.score for node in min_heap.node_list]
- self.bin_results.put_col_optimal_metric_array(col_name, metric_array)
- # LOGGER.debug(f"column {col_name}, split_points: {split_points}, metric array: {metric_array}")
- self.bucket_lists[col_name] = bucket_list
- else:
- optimal_binning_method = functools.partial(self.split_optimal_binning,
- optimal_param=self.optimal_param,
- sample_count=sample_count)
- result_bucket = bucket_table.mapValues(optimal_binning_method)
- for col_name, (bucket_list, non_mixture_num, small_size_num, res_ks_array) in result_bucket.collect():
- split_points = np.unique([bucket.right_bound for bucket in bucket_list]).tolist()
- self.bin_results.put_col_split_points(col_name, split_points)
- self.bin_results.put_col_optimal_metric_array(col_name, res_ks_array)
- # LOGGER.debug(f"column {col_name}, split_points: {split_points}, metric array: {res_ks_array}")
- self.bucket_lists[col_name] = bucket_list
- return result_bucket
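- # init_bucket builds the per-column starting buckets: an initial quantile or bucket
- # binning (per init_bucket_method) produces init_bin_nums split points, each split
- # point becomes a Bucket, and a mapReducePartitions pass fills every bucket with the
- # event / non-event counts of the samples falling into it.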
- def init_bucket(self, data_instances):
- header = data_overview.get_header(data_instances)
- anonymous_header = data_overview.get_anonymous_header(data_instances)
- self._default_setting(header, anonymous_header)
- init_bucket_param = copy.deepcopy(self.params)
- init_bucket_param.bin_num = self.optimal_param.init_bin_nums
- if self.optimal_param.init_bucket_method == consts.QUANTILE:
- init_binning_obj = QuantileBinningTool(param_obj=init_bucket_param, allow_duplicate=False)
- else:
- init_binning_obj = BucketBinning(params=init_bucket_param)
- init_binning_obj.set_bin_inner_param(self.bin_inner_param)
- init_split_points = init_binning_obj.fit_split_points(data_instances)
- is_sparse = data_overview.is_sparse_data(data_instances)
- bucket_dict = dict()
- for col_name, sps in init_split_points.items():
- bucket_list = []
- for idx, sp in enumerate(sps):
- bucket = bucket_info.Bucket(idx, self.adjustment_factor, right_bound=sp)
- if idx == 0:
- bucket.left_bound = -math.inf
- bucket.set_left_neighbor(None)
- else:
- bucket.left_bound = sps[idx - 1]
- bucket.event_total = self.event_total
- bucket.non_event_total = self.non_event_total
- bucket_list.append(bucket)
- bucket_list[-1].set_right_neighbor(None)
- bucket_dict[col_name] = bucket_list
- # LOGGER.debug(f"col_name: {col_name}, length of sps: {len(sps)}, "
- # f"length of list: {len(bucket_list)}")
- convert_func = functools.partial(self.convert_data_to_bucket,
- split_points=init_split_points,
- headers=self.header,
- bucket_dict=copy.deepcopy(bucket_dict),
- is_sparse=is_sparse,
- get_bin_num_func=self.get_bin_num)
- bucket_table = data_instances.mapReducePartitions(convert_func, self.merge_bucket_list)
- return bucket_table
- @staticmethod
- def get_histogram(data_instances):
- static_obj = statics.MultivariateStatisticalSummary(data_instances, cols_index=-1)
- label_histogram = static_obj.get_label_histogram()
- event_total = label_histogram.get(1, 0)
- non_event_total = label_histogram.get(0, 0)
- # if event_total == 0 or non_event_total == 0:
- # LOGGER.warning(f"event_total or non_event_total might have errors, event_total: {event_total},"
- # f" non_event_total: {non_event_total}")
- return event_total, non_event_total
- @staticmethod
- def assign_histogram(bucket_list, event_total, non_event_total):
- for bucket in bucket_list:
- bucket.event_total = event_total
- bucket.non_event_total = non_event_total
- return bucket_list
- @staticmethod
- def merge_bucket_list(list1, list2):
- if len(list1) != len(list2):
- raise AssertionError("In merge bucket list, len of two lists are not equal")
- result = []
- for idx, b1 in enumerate(list1):
- b2 = list2[idx]
- result.append(b1.merge(b2))
- return result
- @staticmethod
- def convert_data_to_bucket(data_iter, split_points, headers, bucket_dict,
- is_sparse, get_bin_num_func):
- for data_key, instance in data_iter:
- label = instance.label
- if not is_sparse:
- if type(instance).__name__ == 'Instance':
- features = instance.features
- else:
- features = instance
- data_generator = enumerate(features)
- else:
- data_generator = instance.features.get_all_data()
- for idx, col_value in data_generator:
- col_name = headers[idx]
- if col_name not in split_points:
- continue
- col_split_points = split_points[col_name]
- bin_num = get_bin_num_func(col_value, col_split_points)
- bucket = bucket_dict[col_name][bin_num]
- bucket.add(label, col_value)
- result = []
- for col_name, bucket_list in bucket_dict.items():
- result.append((col_name, bucket_list))
- return result
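- # merge_optimal_binning implements the bottom-up strategy: adjacent buckets are
- # candidate merges, each candidate is scored with the configured metric and kept in a
- # min-heap, and the cheapest merges are applied in phases (mixture constraints,
- # small-size constraints, then plain count reduction) until max_bin buckets remain.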
- @staticmethod
- def merge_optimal_binning(bucket_list, optimal_param: OptimalBinningParam, sample_count):
- max_item_num = math.floor(optimal_param.max_bin_pct * sample_count)
- min_item_num = math.ceil(optimal_param.min_bin_pct * sample_count)
- bucket_dict = {idx: bucket for idx, bucket in enumerate(bucket_list)}
- final_max_bin = optimal_param.max_bin
- # LOGGER.debug("Get in merge optimal binning, sample_count: {}, max_item_num: {}, min_item_num: {},"
- # "final_max_bin: {}".format(sample_count, max_item_num, min_item_num, final_max_bin))
- min_heap = heap.MinHeap()
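- # _add_heap_nodes scans every bucket and its right neighbor, counts how many buckets
- # still violate the mixture / minimum-size requirements, and pushes a heap node for
- # each neighboring pair that may be merged under the given constraint without
- # exceeding max_item_num.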
- def _add_heap_nodes(constraint=None):
- # LOGGER.debug(f"Add heap nodes, constraint: {}, dict_length: {}".format(constraint, len(bucket_dict)))
- this_non_mixture_num = 0
- this_small_size_num = 0
- # Make bucket satisfy mixture condition
- for i in range(len(bucket_dict)):
- left_bucket = bucket_dict[i]
- right_bucket = bucket_dict.get(left_bucket.right_neighbor_idx)
- if left_bucket.right_neighbor_idx == i:
- raise RuntimeError("left_bucket's right neighbor == itself")
- if not left_bucket.is_mixed:
- this_non_mixture_num += 1
- if left_bucket.total_count < min_item_num:
- this_small_size_num += 1
- if right_bucket is None:
- continue
- # Violate maximum items constraint
- if left_bucket.total_count + right_bucket.total_count > max_item_num:
- continue
- if constraint == 'mixture':
- if left_bucket.is_mixed or right_bucket.is_mixed:
- continue
- elif constraint == 'single_mixture':
- if left_bucket.is_mixed and right_bucket.is_mixed:
- continue
- elif constraint == 'small_size':
- if left_bucket.total_count >= min_item_num or right_bucket.total_count >= min_item_num:
- continue
- elif constraint == 'single_small_size':
- if left_bucket.total_count >= min_item_num and right_bucket.total_count >= min_item_num:
- continue
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket, right_bucket=right_bucket)
- min_heap.insert(heap_node)
- return min_heap, this_non_mixture_num, this_small_size_num
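- # _update_bucket_info re-indexes the surviving buckets consecutively by their left
- # bound and relinks left/right neighbors so the next phase sees a clean chain.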
- def _update_bucket_info(b_dict):
- """
- update bucket information
- """
- order_dict = dict()
- for bucket_idx, item in b_dict.items():
- order_dict[bucket_idx] = item.left_bound
- sorted_order_dict = sorted(order_dict.items(), key=operator.itemgetter(1))
- start_idx = 0
- for item in sorted_order_dict:
- bucket_idx = item[0]
- if start_idx == bucket_idx:
- start_idx += 1
- continue
- b_dict[start_idx] = b_dict[bucket_idx]
- b_dict[start_idx].idx = start_idx
- start_idx += 1
- del b_dict[bucket_idx]
- bucket_num = len(b_dict)
- for i in range(bucket_num):
- if i == 0:
- b_dict[i].set_left_neighbor(None)
- b_dict[i].set_right_neighbor(i + 1)
- else:
- b_dict[i].set_left_neighbor(i - 1)
- b_dict[i].set_right_neighbor(i + 1)
- b_dict[bucket_num - 1].set_right_neighbor(None)
- return b_dict
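- # _merge_heap repeatedly pops the lowest-cost candidate pair, replaces the two buckets
- # with a merged one, prunes stale heap nodes, and updates the remaining count of
- # constraint violations (aim_var) until it reaches zero or the heap empties.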
- def _merge_heap(constraint=None, aim_var=0):
- next_id = max(bucket_dict.keys()) + 1
- while aim_var > 0 and not min_heap.is_empty:
- min_node = min_heap.pop()
- left_bucket = min_node.left_bucket
- right_bucket = min_node.right_bucket
- # Some buckets may be already merged
- if left_bucket.idx not in bucket_dict or right_bucket.idx not in bucket_dict:
- continue
- new_bucket = bucket_info.Bucket(idx=next_id, adjustment_factor=optimal_param.adjustment_factor)
- new_bucket = _init_new_bucket(new_bucket, min_node)
- bucket_dict[next_id] = new_bucket
- del bucket_dict[left_bucket.idx]
- del bucket_dict[right_bucket.idx]
- min_heap.remove_empty_node(left_bucket.idx)
- min_heap.remove_empty_node(right_bucket.idx)
- aim_var = _aim_vars_decrease(constraint, new_bucket, left_bucket, right_bucket, aim_var)
- _add_node_from_new_bucket(new_bucket, constraint)
- next_id += 1
- return min_heap, aim_var
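- # After each merge, _add_node_from_new_bucket re-examines the new bucket against its
- # current left and right neighbors and pushes any pair that is still a legal merge
- # candidate under the active constraint.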
- def _add_node_from_new_bucket(new_bucket: bucket_info.Bucket, constraint):
- left_bucket = bucket_dict.get(new_bucket.left_neighbor_idx)
- right_bucket = bucket_dict.get(new_bucket.right_neighbor_idx)
- if constraint == 'mixture':
- if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
- if not left_bucket.is_mixed and not new_bucket.is_mixed:
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
- right_bucket=new_bucket)
- min_heap.insert(heap_node)
- if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
- if not right_bucket.is_mixed and not new_bucket.is_mixed:
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
- right_bucket=right_bucket)
- min_heap.insert(heap_node)
- elif constraint == 'single_mixture':
- if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
- if not (left_bucket.is_mixed and new_bucket.is_mixed):
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
- right_bucket=new_bucket)
- min_heap.insert(heap_node)
- if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
- if not (right_bucket.is_mixed and new_bucket.is_mixed):
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
- right_bucket=right_bucket)
- min_heap.insert(heap_node)
- elif constraint == 'small_size':
- if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
- if left_bucket.total_count < min_item_num and new_bucket.total_count < min_item_num:
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
- right_bucket=new_bucket)
- min_heap.insert(heap_node)
- if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
- if right_bucket.total_count < min_item_num and new_bucket.total_count < min_item_num:
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
- right_bucket=right_bucket)
- min_heap.insert(heap_node)
- elif constraint == 'single_small_size':
- if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
- if left_bucket.total_count < min_item_num or new_bucket.total_count < min_item_num:
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
- right_bucket=new_bucket)
- min_heap.insert(heap_node)
- if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
- if right_bucket.total_count < min_item_num or new_bucket.total_count < min_item_num:
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
- right_bucket=right_bucket)
- min_heap.insert(heap_node)
- else:
- if left_bucket is not None and left_bucket.total_count + new_bucket.total_count <= max_item_num:
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=left_bucket,
- right_bucket=new_bucket)
- min_heap.insert(heap_node)
- if right_bucket is not None and right_bucket.total_count + new_bucket.total_count <= max_item_num:
- heap_node = heap.heap_node_factory(optimal_param, left_bucket=new_bucket,
- right_bucket=right_bucket)
- min_heap.insert(heap_node)
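- # _init_new_bucket combines the bounds and event counts of the two merged buckets and
- # rewires the neighbor indices of the surrounding buckets to point at it.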
- def _init_new_bucket(new_bucket: bucket_info.Bucket, min_node: heap.HeapNode):
- new_bucket.left_bound = min_node.left_bucket.left_bound
- new_bucket.right_bound = min_node.right_bucket.right_bound
- new_bucket.left_neighbor_idx = min_node.left_bucket.left_neighbor_idx
- new_bucket.right_neighbor_idx = min_node.right_bucket.right_neighbor_idx
- new_bucket.event_count = min_node.left_bucket.event_count + min_node.right_bucket.event_count
- new_bucket.non_event_count = min_node.left_bucket.non_event_count + min_node.right_bucket.non_event_count
- new_bucket.event_total = min_node.left_bucket.event_total
- new_bucket.non_event_total = min_node.left_bucket.non_event_total
- left_neighbor_bucket = bucket_dict.get(new_bucket.left_neighbor_idx)
- if left_neighbor_bucket is not None:
- left_neighbor_bucket.right_neighbor_idx = new_bucket.idx
- right_neighbor_bucket = bucket_dict.get(new_bucket.right_neighbor_idx)
- if right_neighbor_bucket is not None:
- right_neighbor_bucket.left_neighbor_idx = new_bucket.idx
- return new_bucket
- def _aim_vars_decrease(constraint, new_bucket: bucket_info.Bucket, left_bucket, right_bucket, aim_var):
- if constraint in ['mixture', 'single_mixture']:
- if not left_bucket.is_mixed:
- aim_var -= 1
- if not right_bucket.is_mixed:
- aim_var -= 1
- if not new_bucket.is_mixed:
- aim_var += 1
- elif constraint in ['small_size', 'single_small_size']:
- if left_bucket.total_count < min_item_num:
- aim_var -= 1
- if right_bucket.total_count < min_item_num:
- aim_var -= 1
- if new_bucket.total_count < min_item_num:
- aim_var += 1
- else:
- aim_var = len(bucket_dict) - final_max_bin
- return aim_var
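- # Driver: optionally enforce the mixture constraint (strict, then single-sided), then
- # the small-size constraint (strict, then single-sided), and finally merge without
- # constraints until only final_max_bin buckets are left.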
- if optimal_param.mixture:
- LOGGER.debug(f"Before mixture add, dict length: {len(bucket_dict)}")
- min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='mixture')
- min_heap, non_mixture_num = _merge_heap(constraint='mixture', aim_var=non_mixture_num)
- bucket_dict = _update_bucket_info(bucket_dict)
- min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='single_mixture')
- min_heap, non_mixture_num = _merge_heap(constraint='single_mixture', aim_var=non_mixture_num)
- LOGGER.debug(f"After mixture merge, min_heap size: {min_heap.size}, non_mixture_num: {non_mixture_num}")
- bucket_dict = _update_bucket_info(bucket_dict)
- LOGGER.debug(f"Before small_size add, dict length: {len(bucket_dict)}")
- min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='small_size')
- min_heap, small_size_num = _merge_heap(constraint='small_size', aim_var=small_size_num)
- bucket_dict = _update_bucket_info(bucket_dict)
- min_heap, non_mixture_num, small_size_num = _add_heap_nodes(constraint='single_small_size')
- min_heap, small_size_num = _merge_heap(constraint='single_small_size', aim_var=small_size_num)
- bucket_dict = _update_bucket_info(bucket_dict)
- # LOGGER.debug(f"Before add, dict length: {len(bucket_dict)}")
- min_heap, non_mixture_num, small_size_num = _add_heap_nodes()
- # LOGGER.debug("After normal add, small_size: {}, min_heap size: {}".format(small_size_num, min_heap.size))
- min_heap, total_bucket_num = _merge_heap(aim_var=len(bucket_dict) - final_max_bin)
- # LOGGER.debug("After normal merge, min_heap size: {}".format(min_heap.size))
- non_mixture_num = 0
- small_size_num = 0
- for bucket in bucket_dict.values():
- if not bucket.is_mixed:
- non_mixture_num += 1
- if bucket.total_count < min_item_num:
- small_size_num += 1
- bucket_res = list(bucket_dict.values())
- bucket_res = sorted(bucket_res, key=lambda bucket: bucket.left_bound)
- # LOGGER.debug("Before return, dict length: {}".format(len(bucket_dict)))
- # LOGGER.debug(f"Before return, min heap node list length: {len(min_heap.node_list)}")
- return min_heap, bucket_res, non_mixture_num, small_size_num
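- # split_optimal_binning implements the top-down strategy used when the metric is not
- # iv / gini / chi_square (typically ks): starting from the full bucket list, it
- # repeatedly splits the segment at the index with the largest KS statistic, subject to
- # the mixture and minimum-size constraints, until final_max_bin bins are produced.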
- @staticmethod
- def split_optimal_binning(bucket_list, optimal_param: OptimalBinningParam, sample_count):
- min_item_num = math.ceil(optimal_param.min_bin_pct * sample_count)
- final_max_bin = optimal_param.max_bin
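- # _compute_ks accumulates event / non-event rates over bucket_list[start_idx:end_idx]
- # and returns the position with the largest |event_rate - non_event_rate| (the KS
- # statistic), falling back to the middle index when the best split would produce an
- # undersized bin.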
- def _compute_ks(start_idx, end_idx):
- acc_event = []
- acc_non_event = []
- curt_event_total = 0
- curt_non_event_total = 0
- for bucket in bucket_list[start_idx: end_idx]:
- acc_event.append(bucket.event_count + curt_event_total)
- curt_event_total += bucket.event_count
- acc_non_event.append(bucket.non_event_count + curt_non_event_total)
- curt_non_event_total += bucket.non_event_count
- if curt_event_total == 0 or curt_non_event_total == 0:
- return None, None, None
- acc_event_rate = [x / curt_event_total for x in acc_event]
- acc_non_event_rate = [x / curt_non_event_total for x in acc_non_event]
- ks_list = [math.fabs(eve - non_eve) for eve, non_eve in zip(acc_event_rate, acc_non_event_rate)]
- if max(ks_list) == 0:
- best_index = len(ks_list) // 2
- else:
- best_index = ks_list.index(max(ks_list))
- left_event = acc_event[best_index]
- right_event = curt_event_total - left_event
- left_non_event = acc_non_event[best_index]
- right_non_event = curt_non_event_total - left_non_event
- left_total = left_event + left_non_event
- right_total = right_event + right_non_event
- if left_total < min_item_num or right_total < min_item_num:
- best_index = len(ks_list) // 2
- left_event = acc_event[best_index]
- right_event = curt_event_total - left_event
- left_non_event = acc_non_event[best_index]
- right_non_event = curt_non_event_total - left_non_event
- left_total = left_event + left_non_event
- right_total = right_event + right_non_event
- best_ks = ks_list[best_index]
- res_dict = {
- 'left_event': left_event,
- 'right_event': right_event,
- 'left_non_event': left_non_event,
- 'right_non_event': right_non_event,
- 'left_total': left_total,
- 'right_total': right_total,
- 'left_is_mixed': left_event > 0 and left_non_event > 0,
- 'right_is_mixed': right_event > 0 and right_non_event > 0
- }
- return best_ks, start_idx + best_index, res_dict
- def _merge_buckets(start_idx, end_idx, bucket_idx):
- res_bucket = copy.deepcopy(bucket_list[start_idx])
- res_bucket.idx = bucket_idx
- for bucket in bucket_list[start_idx + 1: end_idx]:
- res_bucket = res_bucket.merge(bucket)
- return res_bucket
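- # Breadth-first splitting: keep a queue of (start, end) segments, split the segment at
- # the head of the queue, and enqueue both halves (larger half first) until
- # final_max_bin - 1 split points have been found or no segment can be split further.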
- res_split_index = []
- res_split_ks = {}
- to_split_pair = [(0, len(bucket_list))]
- # iteratively split
- while len(to_split_pair) > 0:
- if len(res_split_index) >= final_max_bin - 1:
- break
- start, end = to_split_pair.pop(0)
- if start >= end:
- continue
- best_ks, best_index, res_dict = _compute_ks(start, end)
- if best_ks is None:
- continue
- if optimal_param.mixture:
- if not (res_dict.get('left_is_mixed') and res_dict.get('right_is_mixed')):
- continue
- if res_dict.get('left_total') < min_item_num or res_dict.get('right_total') < min_item_num:
- continue
- res_split_index.append(best_index + 1)
- res_split_ks[best_index + 1] = best_ks
- if res_dict.get('right_total') > res_dict.get('left_total'):
- to_split_pair.append((best_index + 1, end))
- to_split_pair.append((start, best_index + 1))
- else:
- to_split_pair.append((start, best_index + 1))
- to_split_pair.append((best_index + 1, end))
- # LOGGER.debug("to_split_pair: {}".format(to_split_pair))
- if len(res_split_index) == 0:
- LOGGER.warning("Best ks optimal binning fail to split. Take middle split point instead")
- res_split_index.append(len(bucket_list) // 2)
- res_split_index = sorted(res_split_index)
- res_ks = []
- if res_split_ks:
- res_ks = [res_split_ks[idx] for idx in res_split_index]
- # last bin
- # res_ks.append(0.0)
- res_split_index.append(len(bucket_list))
- start = 0
- bucket_res = []
- non_mixture_num = 0
- small_size_num = 0
- for bucket_idx, end in enumerate(res_split_index):
- new_bucket = _merge_buckets(start, end, bucket_idx)
- bucket_res.append(new_bucket)
- if not new_bucket.is_mixed:
- non_mixture_num += 1
- if new_bucket.total_count < min_item_num:
- small_size_num += 1
- start = end
- return bucket_res, non_mixture_num, small_size_num, res_ks
- def bin_sum_to_bucket_list(self, bin_sum, partitions):
- """
- Convert a bin sum result, typically received from the host, into bucket lists
- Parameters
- ----------
- bin_sum : dict
- {'x1': [[event_count, non_event_count], [event_count, non_event_count] ... ],
- 'x2': [[event_count, non_event_count], [event_count, non_event_count] ... ],
- ...
- }
- partitions: int
- Number of partitions for the created table.
- Returns
- -------
- A Table whose keys are feature names and whose values are bucket lists.
- """
- bucket_dict = dict()
- for col_name, bin_res_list in bin_sum.items():
- bucket_list = []
- for b_idx in range(len(bin_res_list)):
- bucket = bucket_info.Bucket(b_idx, self.adjustment_factor)
- if b_idx == 0:
- bucket.set_left_neighbor(None)
- if b_idx == len(bin_res_list) - 1:
- bucket.set_right_neighbor(None)
- bucket.event_count = bin_res_list[b_idx][0]
- bucket.non_event_count = bin_res_list[b_idx][1]
- bucket.left_bound = b_idx - 1
- bucket.right_bound = b_idx
- bucket.event_total = self.event_total
- bucket.non_event_total = self.non_event_total
- bucket_list.append(bucket)
- bucket_dict[col_name] = bucket_list
- result = []
- for col_name, bucket_list in bucket_dict.items():
- result.append((col_name, bucket_list))
- result_table = session.parallelize(result,
- include_key=True,
- partition=partitions)
- return result_table