#!/usr/bin/env python # -*- coding: utf-8 -*- # # Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import numpy as np from federatedml.feature.binning.optimal_binning.bucket_info import Bucket from federatedml.param.feature_binning_param import OptimalBinningParam from federatedml.util import LOGGER class HeapNode(object): def __init__(self): self.left_bucket: Bucket = None self.right_bucket: Bucket = None self.event_count = 0 self.non_event_count = 0 self.score = None def cal_score(self): raise NotImplementedError("Should not call here") @property def total_count(self): return self.event_count + self.non_event_count class IvHeapNode(HeapNode): def __init__(self, adjustment_factor=0.5): super().__init__() self.adjustment_factor = adjustment_factor self.event_total = 0 self.non_event_total = 0 def cal_score(self): """ IV = ∑(py_i - pn_i ) * WOE where py_i is event_rate, pn_i is non_event_rate WOE = log(non_event_rate / event_rate) """ self.event_count = self.left_bucket.event_count + self.right_bucket.event_count self.non_event_count = self.left_bucket.non_event_count + self.right_bucket.non_event_count if self.total_count == 0: self.score = -math.inf return # if self.left_bucket.left_bound != math.inf and self.right_bucket.right_bound != -math.inf: # if (self.left_bucket.left_bound <= self.right_bucket.right_bound): # self.score = -math.inf # return self.event_total = self.left_bucket.event_total self.non_event_total = self.left_bucket.non_event_total if self.event_count == 0 or self.non_event_count == 0: event_rate = 1.0 * (self.event_count + self.adjustment_factor) / max(self.event_total, 1) non_event_rate = 1.0 * (self.non_event_count + self.adjustment_factor) / max(self.non_event_total, 1) else: event_rate = 1.0 * self.event_count / max(self.event_total, 1) non_event_rate = 1.0 * self.non_event_count / max(self.non_event_total, 1) merge_woe = math.log(event_rate / non_event_rate) merge_iv = (event_rate - non_event_rate) * merge_woe self.score = self.left_bucket.iv + self.right_bucket.iv - merge_iv class GiniHeapNode(HeapNode): def cal_score(self): """ gini = 1 - ∑(p_i^2 ) = 1 -(event / total)^2 - (nonevent / total)^2 """ self.event_count = self.left_bucket.event_count + self.right_bucket.event_count self.non_event_count = self.left_bucket.non_event_count + self.right_bucket.non_event_count if self.total_count == 0: self.score = -math.inf return # if self.total_count == 0 or self.left_bucket.left_bound == self.right_bucket.right_bound: # self.score = -math.inf # return merged_gini = 1 - (1.0 * self.event_count / self.total_count) ** 2 - \ (1.0 * self.non_event_count / self.total_count) ** 2 self.score = merged_gini - self.left_bucket.gini - self.right_bucket.gini class ChiSquareHeapNode(HeapNode): def cal_score(self): """ X^2 = ∑∑(A_ij - E_ij )^2 / E_ij where E_ij = (N_i / N) * C_j. N is total count of merged bucket, N_i is the total count of ith bucket and C_j is the count of jth label in merged bucket. A_ij is number of jth label in ith bucket. """ self.event_count = self.left_bucket.event_count + self.right_bucket.event_count self.non_event_count = self.left_bucket.non_event_count + self.right_bucket.non_event_count if self.total_count == 0: self.score = -math.inf return c1 = self.left_bucket.event_count + self.right_bucket.event_count c0 = self.left_bucket.non_event_count + self.right_bucket.non_event_count if c1 == 0 or c0 == 0: self.score = - math.inf return e_left_1 = (self.left_bucket.total_count / self.total_count) * c1 e_left_0 = (self.left_bucket.total_count / self.total_count) * c0 e_right_1 = (self.right_bucket.total_count / self.total_count) * c1 e_right_0 = (self.right_bucket.total_count / self.total_count) * c0 chi_square = np.square(self.left_bucket.event_count - e_left_1) / e_left_1 + \ np.square(self.left_bucket.non_event_count - e_left_0) / e_left_0 + \ np.square(self.right_bucket.event_count - e_right_1) / e_right_1 + \ np.square(self.right_bucket.non_event_count - e_right_0) / e_right_0 LOGGER.debug("chi_sqaure: {}".format(chi_square)) self.score = chi_square def heap_node_factory(optimal_param: OptimalBinningParam, left_bucket=None, right_bucket=None): metric_method = optimal_param.metric_method if metric_method == 'iv': node = IvHeapNode(adjustment_factor=optimal_param.adjustment_factor) elif metric_method == 'gini': node = GiniHeapNode() elif metric_method == 'chi_square': node = ChiSquareHeapNode() else: raise ValueError("metric_method: {} cannot recognized".format(metric_method)) if left_bucket is not None: node.left_bucket = left_bucket if right_bucket is not None: node.right_bucket = right_bucket if (left_bucket and right_bucket) is not None: node.cal_score() else: LOGGER.warning("In heap factory, left_bucket is {}, right bucket is {}, not all of them has been assign".format( left_bucket, right_bucket )) return node class MinHeap(object): def __init__(self): self.size = 0 self.node_list = [] @property def is_empty(self): return self.size <= 0 def insert(self, heap_node: HeapNode): self.size += 1 self.node_list.append(heap_node) self._move_up(self.size - 1) def remove_empty_node(self, removed_bucket_id): for n_id, node in enumerate(self.node_list): if node.left_bucket.idx == removed_bucket_id or node.right_bucket.idx == removed_bucket_id: self.delete_index_k(n_id) def delete_index_k(self, k): if k >= self.size: return if k == self.size - 1: self.node_list.pop() self.size -= 1 else: self.node_list[k] = self.node_list[self.size - 1] self.node_list.pop() self.size -= 1 if k == 0: self._move_down(k) else: parent_idx = self._get_parent_index(k) if self.node_list[parent_idx].score < self.node_list[k].score: self._move_down(k) else: self._move_up(k) def pop(self): min_node = self.node_list[0] if not self.is_empty else None if min_node is not None: self.node_list[0] = self.node_list[self.size - 1] self.node_list.pop() self.size -= 1 self._move_down(0) return min_node def _switch_node(self, idx_1, idx_2): if idx_1 >= self.size or idx_2 >= self.size: return self.node_list[idx_1], self.node_list[idx_2] = self.node_list[idx_2], self.node_list[idx_1] @staticmethod def _get_parent_index(index): if index == 0: return None parent_index = (index - 1) / 2 return int(parent_index) if parent_index >= 0 else None def _get_left_child_idx(self, idx): child_index = (2 * idx) + 1 return child_index if child_index < self.size else None def _get_right_child_idx(self, idx): child_index = (2 * idx) + 2 return child_index if child_index < self.size else None def _move_down(self, curt_idx): if curt_idx >= self.size: return min_idx = curt_idx while True: left_child_idx = self._get_left_child_idx(curt_idx) right_child_idx = self._get_right_child_idx(curt_idx) if left_child_idx is not None and self.node_list[left_child_idx].score < self.node_list[curt_idx].score: min_idx = left_child_idx if right_child_idx is not None and self.node_list[right_child_idx].score < self.node_list[min_idx].score: min_idx = right_child_idx if min_idx != curt_idx: self._switch_node(curt_idx, min_idx) curt_idx = min_idx else: break def _move_up(self, curt_idx): if curt_idx >= self.size: return while True: parent_idx = self._get_parent_index(curt_idx) if parent_idx is None: break if self.node_list[curt_idx].score < self.node_list[parent_idx].score: self._switch_node(curt_idx, parent_idx) curt_idx = parent_idx else: break