bucket_info.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import math
  18. class Bucket(object):
  19. def __init__(self, idx=-1, adjustment_factor=0.5, right_bound=-math.inf):
  20. self.idx = idx
  21. self.left_bound = math.inf
  22. self.right_bound = right_bound
  23. self.left_neighbor_idx = idx - 1
  24. self.right_neighbor_idx = idx + 1
  25. self.event_count = 0
  26. self.non_event_count = 0
  27. self.adjustment_factor = adjustment_factor
  28. self.event_total = None
  29. self.non_event_total = None
  30. def set_left_neighbor(self, left_idx):
  31. self.left_neighbor_idx = left_idx
  32. def set_right_neighbor(self, right_idx):
  33. self.right_neighbor_idx = right_idx
  34. @property
  35. def is_mixed(self):
  36. return self.event_count > 0 and self.non_event_count > 0
  37. @property
  38. def total_count(self):
  39. return self.event_count + self.non_event_count
  40. def merge(self, other):
  41. if other is None:
  42. return
  43. if other.left_bound < self.left_bound:
  44. self.left_bound = other.left_bound
  45. if other.right_bound > self.right_bound:
  46. self.right_bound = other.right_bound
  47. self.event_count += other.event_count
  48. self.non_event_count += other.non_event_count
  49. return self
  50. def add(self, label, value):
  51. if label == 1:
  52. self.event_count += 1
  53. else:
  54. self.non_event_count += 1
  55. if value < self.left_bound:
  56. self.left_bound = value
  57. if value > self.right_bound:
  58. self.right_bound = value
  59. @property
  60. def iv(self):
  61. if self.event_total is None or self.non_event_total is None:
  62. raise AssertionError("Bucket's event_total or non_event_total has not been assigned")
  63. # only have EVENT records or Non-Event records
  64. if self.event_count == 0 or self.non_event_count == 0:
  65. event_rate = 1.0 * (self.event_count + self.adjustment_factor) / max(self.event_total, 1)
  66. non_event_rate = 1.0 * (self.non_event_count + self.adjustment_factor) / max(self.non_event_total, 1)
  67. else:
  68. event_rate = 1.0 * self.event_count / max(self.event_total, 1)
  69. non_event_rate = 1.0 * self.non_event_count / max(self.non_event_total, 1)
  70. woe = math.log(non_event_rate / event_rate)
  71. return (non_event_rate - event_rate) * woe
  72. @property
  73. def gini(self):
  74. if self.total_count == 0:
  75. return 0
  76. return 1 - (1.0 * self.event_count / self.total_count) ** 2 - \
  77. (1.0 * self.non_event_count / self.total_count) ** 2