heap.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  17. import math
  18. import numpy as np
  19. from federatedml.feature.binning.optimal_binning.bucket_info import Bucket
  20. from federatedml.param.feature_binning_param import OptimalBinningParam
  21. from federatedml.util import LOGGER
  22. class HeapNode(object):
  23. def __init__(self):
  24. self.left_bucket: Bucket = None
  25. self.right_bucket: Bucket = None
  26. self.event_count = 0
  27. self.non_event_count = 0
  28. self.score = None
  29. def cal_score(self):
  30. raise NotImplementedError("Should not call here")
  31. @property
  32. def total_count(self):
  33. return self.event_count + self.non_event_count
  34. class IvHeapNode(HeapNode):
  35. def __init__(self, adjustment_factor=0.5):
  36. super().__init__()
  37. self.adjustment_factor = adjustment_factor
  38. self.event_total = 0
  39. self.non_event_total = 0
  40. def cal_score(self):
  41. """
  42. IV = ∑(py_i - pn_i ) * WOE
  43. where py_i is event_rate, pn_i is non_event_rate
  44. WOE = log(non_event_rate / event_rate)
  45. """
  46. self.event_count = self.left_bucket.event_count + self.right_bucket.event_count
  47. self.non_event_count = self.left_bucket.non_event_count + self.right_bucket.non_event_count
  48. if self.total_count == 0:
  49. self.score = -math.inf
  50. return
  51. # if self.left_bucket.left_bound != math.inf and self.right_bucket.right_bound != -math.inf:
  52. # if (self.left_bucket.left_bound <= self.right_bucket.right_bound):
  53. # self.score = -math.inf
  54. # return
  55. self.event_total = self.left_bucket.event_total
  56. self.non_event_total = self.left_bucket.non_event_total
  57. if self.event_count == 0 or self.non_event_count == 0:
  58. event_rate = 1.0 * (self.event_count + self.adjustment_factor) / max(self.event_total, 1)
  59. non_event_rate = 1.0 * (self.non_event_count + self.adjustment_factor) / max(self.non_event_total, 1)
  60. else:
  61. event_rate = 1.0 * self.event_count / max(self.event_total, 1)
  62. non_event_rate = 1.0 * self.non_event_count / max(self.non_event_total, 1)
  63. merge_woe = math.log(event_rate / non_event_rate)
  64. merge_iv = (event_rate - non_event_rate) * merge_woe
  65. self.score = self.left_bucket.iv + self.right_bucket.iv - merge_iv
  66. class GiniHeapNode(HeapNode):
  67. def cal_score(self):
  68. """
  69. gini = 1 - ∑(p_i^2 ) = 1 -(event / total)^2 - (nonevent / total)^2
  70. """
  71. self.event_count = self.left_bucket.event_count + self.right_bucket.event_count
  72. self.non_event_count = self.left_bucket.non_event_count + self.right_bucket.non_event_count
  73. if self.total_count == 0:
  74. self.score = -math.inf
  75. return
  76. # if self.total_count == 0 or self.left_bucket.left_bound == self.right_bucket.right_bound:
  77. # self.score = -math.inf
  78. # return
  79. merged_gini = 1 - (1.0 * self.event_count / self.total_count) ** 2 - \
  80. (1.0 * self.non_event_count / self.total_count) ** 2
  81. self.score = merged_gini - self.left_bucket.gini - self.right_bucket.gini
  82. class ChiSquareHeapNode(HeapNode):
  83. def cal_score(self):
  84. """
  85. X^2 = ∑∑(A_ij - E_ij )^2 / E_ij
  86. where E_ij = (N_i / N) * C_j. N is total count of merged bucket, N_i is the total count of ith bucket
  87. and C_j is the count of jth label in merged bucket.
  88. A_ij is number of jth label in ith bucket.
  89. """
  90. self.event_count = self.left_bucket.event_count + self.right_bucket.event_count
  91. self.non_event_count = self.left_bucket.non_event_count + self.right_bucket.non_event_count
  92. if self.total_count == 0:
  93. self.score = -math.inf
  94. return
  95. c1 = self.left_bucket.event_count + self.right_bucket.event_count
  96. c0 = self.left_bucket.non_event_count + self.right_bucket.non_event_count
  97. if c1 == 0 or c0 == 0:
  98. self.score = - math.inf
  99. return
  100. e_left_1 = (self.left_bucket.total_count / self.total_count) * c1
  101. e_left_0 = (self.left_bucket.total_count / self.total_count) * c0
  102. e_right_1 = (self.right_bucket.total_count / self.total_count) * c1
  103. e_right_0 = (self.right_bucket.total_count / self.total_count) * c0
  104. chi_square = np.square(self.left_bucket.event_count - e_left_1) / e_left_1 + \
  105. np.square(self.left_bucket.non_event_count - e_left_0) / e_left_0 + \
  106. np.square(self.right_bucket.event_count - e_right_1) / e_right_1 + \
  107. np.square(self.right_bucket.non_event_count - e_right_0) / e_right_0
  108. LOGGER.debug("chi_sqaure: {}".format(chi_square))
  109. self.score = chi_square
  110. def heap_node_factory(optimal_param: OptimalBinningParam, left_bucket=None, right_bucket=None):
  111. metric_method = optimal_param.metric_method
  112. if metric_method == 'iv':
  113. node = IvHeapNode(adjustment_factor=optimal_param.adjustment_factor)
  114. elif metric_method == 'gini':
  115. node = GiniHeapNode()
  116. elif metric_method == 'chi_square':
  117. node = ChiSquareHeapNode()
  118. else:
  119. raise ValueError("metric_method: {} cannot recognized".format(metric_method))
  120. if left_bucket is not None:
  121. node.left_bucket = left_bucket
  122. if right_bucket is not None:
  123. node.right_bucket = right_bucket
  124. if (left_bucket and right_bucket) is not None:
  125. node.cal_score()
  126. else:
  127. LOGGER.warning("In heap factory, left_bucket is {}, right bucket is {}, not all of them has been assign".format(
  128. left_bucket, right_bucket
  129. ))
  130. return node
  131. class MinHeap(object):
  132. def __init__(self):
  133. self.size = 0
  134. self.node_list = []
  135. @property
  136. def is_empty(self):
  137. return self.size <= 0
  138. def insert(self, heap_node: HeapNode):
  139. self.size += 1
  140. self.node_list.append(heap_node)
  141. self._move_up(self.size - 1)
  142. def remove_empty_node(self, removed_bucket_id):
  143. for n_id, node in enumerate(self.node_list):
  144. if node.left_bucket.idx == removed_bucket_id or node.right_bucket.idx == removed_bucket_id:
  145. self.delete_index_k(n_id)
  146. def delete_index_k(self, k):
  147. if k >= self.size:
  148. return
  149. if k == self.size - 1:
  150. self.node_list.pop()
  151. self.size -= 1
  152. else:
  153. self.node_list[k] = self.node_list[self.size - 1]
  154. self.node_list.pop()
  155. self.size -= 1
  156. if k == 0:
  157. self._move_down(k)
  158. else:
  159. parent_idx = self._get_parent_index(k)
  160. if self.node_list[parent_idx].score < self.node_list[k].score:
  161. self._move_down(k)
  162. else:
  163. self._move_up(k)
  164. def pop(self):
  165. min_node = self.node_list[0] if not self.is_empty else None
  166. if min_node is not None:
  167. self.node_list[0] = self.node_list[self.size - 1]
  168. self.node_list.pop()
  169. self.size -= 1
  170. self._move_down(0)
  171. return min_node
  172. def _switch_node(self, idx_1, idx_2):
  173. if idx_1 >= self.size or idx_2 >= self.size:
  174. return
  175. self.node_list[idx_1], self.node_list[idx_2] = self.node_list[idx_2], self.node_list[idx_1]
  176. @staticmethod
  177. def _get_parent_index(index):
  178. if index == 0:
  179. return None
  180. parent_index = (index - 1) / 2
  181. return int(parent_index) if parent_index >= 0 else None
  182. def _get_left_child_idx(self, idx):
  183. child_index = (2 * idx) + 1
  184. return child_index if child_index < self.size else None
  185. def _get_right_child_idx(self, idx):
  186. child_index = (2 * idx) + 2
  187. return child_index if child_index < self.size else None
  188. def _move_down(self, curt_idx):
  189. if curt_idx >= self.size:
  190. return
  191. min_idx = curt_idx
  192. while True:
  193. left_child_idx = self._get_left_child_idx(curt_idx)
  194. right_child_idx = self._get_right_child_idx(curt_idx)
  195. if left_child_idx is not None and self.node_list[left_child_idx].score < self.node_list[curt_idx].score:
  196. min_idx = left_child_idx
  197. if right_child_idx is not None and self.node_list[right_child_idx].score < self.node_list[min_idx].score:
  198. min_idx = right_child_idx
  199. if min_idx != curt_idx:
  200. self._switch_node(curt_idx, min_idx)
  201. curt_idx = min_idx
  202. else:
  203. break
  204. def _move_up(self, curt_idx):
  205. if curt_idx >= self.size:
  206. return
  207. while True:
  208. parent_idx = self._get_parent_index(curt_idx)
  209. if parent_idx is None:
  210. break
  211. if self.node_list[curt_idx].score < self.node_list[parent_idx].score:
  212. self._switch_node(curt_idx, parent_idx)
  213. curt_idx = parent_idx
  214. else:
  215. break