bin_result.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import numpy as np
  18. from federatedml.protobuf.generated import feature_binning_param_pb2
  19. from federatedml.util import LOGGER
  20. class BinColResults(object):
  21. def __init__(self, woe_array=(), iv_array=(), event_count_array=(), non_event_count_array=(),
  22. event_rate_array=(), non_event_rate_array=(), iv=None, optimal_metric_array=()):
  23. self.woe_array = list(woe_array)
  24. self.iv_array = list(iv_array)
  25. self.event_count_array = list(event_count_array)
  26. self.non_event_count_array = list(non_event_count_array)
  27. self.event_rate_array = list(event_rate_array)
  28. self.non_event_rate_array = list(non_event_rate_array)
  29. self.split_points = None
  30. if iv is None:
  31. iv = 0
  32. for idx, woe in enumerate(self.woe_array):
  33. non_event_rate = non_event_count_array[idx]
  34. event_rate = event_rate_array[idx]
  35. iv += (non_event_rate - event_rate) * woe
  36. self.iv = iv
  37. self._bin_anonymous = None
  38. self.optimal_metric_array = list(optimal_metric_array)
  39. @property
  40. def bin_anonymous(self):
  41. if self.split_points is None or len(self.split_points) == 0:
  42. return []
  43. if self._bin_anonymous is None:
  44. return ["bin_" + str(i) for i in range(len(self.split_points))]
  45. return self._bin_anonymous
  46. @bin_anonymous.setter
  47. def bin_anonymous(self, x):
  48. self._bin_anonymous = x
  49. def set_split_points(self, split_points):
  50. self.split_points = split_points
  51. def set_optimal_metric(self, metric_array):
  52. self.optimal_metric_array = metric_array
  53. def get_split_points(self):
  54. return np.array(self.split_points)
  55. @property
  56. def is_woe_monotonic(self):
  57. """
  58. Check the woe is monotonic or not
  59. """
  60. woe_array = self.woe_array
  61. if len(woe_array) <= 1:
  62. return True
  63. is_increasing = all(x <= y for x, y in zip(woe_array, woe_array[1:]))
  64. is_decreasing = all(x >= y for x, y in zip(woe_array, woe_array[1:]))
  65. return is_increasing or is_decreasing
  66. @property
  67. def bin_nums(self):
  68. return len(self.woe_array)
  69. def result_dict(self):
  70. save_dict = self.__dict__
  71. save_dict['is_woe_monotonic'] = self.is_woe_monotonic
  72. save_dict['bin_nums'] = self.bin_nums
  73. return save_dict
  74. def reconstruct(self, iv_obj):
  75. self.woe_array = list(iv_obj.woe_array)
  76. self.iv_array = list(iv_obj.iv_array)
  77. self.event_count_array = list(iv_obj.event_count_array)
  78. self.non_event_count_array = list(iv_obj.non_event_count_array)
  79. self.event_rate_array = list(iv_obj.event_rate_array)
  80. self.non_event_rate_array = list(iv_obj.non_event_rate_array)
  81. self.split_points = list(iv_obj.split_points)
  82. self.iv = iv_obj.iv
  83. # new attribute since ver 1.10
  84. if hasattr(iv_obj, "optimal_metric_array"):
  85. self.optimal_metric_array = list(iv_obj.optimal_metric_array)
  86. def generate_pb_dict(self):
  87. result = {
  88. "woe_array": self.woe_array,
  89. "iv_array": self.iv_array,
  90. "event_count_array": self.event_count_array,
  91. "non_event_count_array": self.non_event_count_array,
  92. "event_rate_array": self.event_rate_array,
  93. "non_event_rate_array": self.non_event_rate_array,
  94. "split_points": self.split_points,
  95. "iv": self.iv,
  96. "is_woe_monotonic": self.is_woe_monotonic,
  97. "bin_nums": self.bin_nums,
  98. "bin_anonymous": self.bin_anonymous,
  99. "optimal_metric_array": self.optimal_metric_array
  100. }
  101. return result
  102. class SplitPointsResult(object):
  103. def __init__(self):
  104. self.split_results = {}
  105. self.optimal_metric = {}
  106. def put_col_split_points(self, col_name, split_points):
  107. self.split_results[col_name] = split_points
  108. def put_col_optimal_metric_array(self, col_name, metric_array):
  109. self.optimal_metric[col_name] = metric_array
  110. @property
  111. def all_split_points(self):
  112. return self.split_results
  113. @property
  114. def all_optimal_metric(self):
  115. return self.optimal_metric
  116. def get_split_points_array(self, col_names):
  117. split_points_result = []
  118. for col_name in col_names:
  119. if col_name not in self.split_results:
  120. continue
  121. split_points_result.append(self.split_results[col_name])
  122. return np.array(split_points_result)
  123. def to_json(self):
  124. return {k: list(v) for k, v in self.split_results.items()}
  125. class BinResults(object):
  126. def __init__(self):
  127. self.all_cols_results = {} # {col_name: BinColResult}
  128. self.role = ''
  129. self.party_id = ''
  130. def set_role_party(self, role, party_id):
  131. self.role = role
  132. self.party_id = party_id
  133. def put_col_results(self, col_name, col_results: BinColResults):
  134. ori_col_results = self.all_cols_results.get(col_name)
  135. if ori_col_results is not None:
  136. col_results.set_split_points(ori_col_results.get_split_points())
  137. self.all_cols_results[col_name] = col_results
  138. def put_col_split_points(self, col_name, split_points):
  139. col_results = self.all_cols_results.get(col_name, BinColResults())
  140. col_results.set_split_points(split_points)
  141. self.all_cols_results[col_name] = col_results
  142. def query_split_points(self, col_name):
  143. col_results = self.all_cols_results.get(col_name)
  144. if col_results is None:
  145. LOGGER.warning("Querying non-exist split_points")
  146. return None
  147. return col_results.split_points
  148. def put_optimal_metric_array(self, col_name, metric_array):
  149. col_results = self.all_cols_results.get(col_name, BinColResults())
  150. col_results.set_optimal_metric(metric_array)
  151. self.all_cols_results[col_name] = col_results
  152. @property
  153. def all_split_points(self):
  154. results = {}
  155. for col_name, col_result in self.all_cols_results.items():
  156. results[col_name] = col_result.get_split_points()
  157. return results
  158. @property
  159. def all_ivs(self):
  160. return [(col_name, x.iv) for col_name, x in self.all_cols_results.items()]
  161. @property
  162. def all_woes(self):
  163. return {col_name: x.woe_array for col_name, x in self.all_cols_results.items()}
  164. @property
  165. def all_monotonic(self):
  166. return {col_name: x.is_woe_monotonic for col_name, x in self.all_cols_results.items()}
  167. @property
  168. def all_optimal_metric(self):
  169. return {col_name: x.optimal_metric_array for col_name, x in self.all_cols_results.items()}
  170. def summary(self, split_points=None):
  171. if split_points is None:
  172. split_points = {}
  173. for col_name, x in self.all_cols_results.items():
  174. sp = x.get_split_points().tolist()
  175. split_points[col_name] = sp
  176. # split_points = {col_name: x.split_points for col_name, x in self.all_cols_results.items()}
  177. return {"iv": self.all_ivs,
  178. "woe": self.all_woes,
  179. "monotonic": self.all_monotonic,
  180. "split_points": split_points}
  181. def generated_pb(self, split_points=None):
  182. col_result_dict = {}
  183. if split_points is not None:
  184. for col_name, sp in split_points.items():
  185. self.put_col_split_points(col_name, sp)
  186. for col_name, col_bin_result in self.all_cols_results.items():
  187. bin_res_dict = col_bin_result.generate_pb_dict()
  188. # LOGGER.debug(f"col name: {col_name}, bin_res_dict: {bin_res_dict}")
  189. col_result_dict[col_name] = feature_binning_param_pb2.IVParam(**bin_res_dict)
  190. # LOGGER.debug("In generated_pb, role: {}, party_id: {}".format(self.role, self.party_id))
  191. result_pb = feature_binning_param_pb2.FeatureBinningResult(binning_result=col_result_dict,
  192. role=self.role,
  193. party_id=str(self.party_id))
  194. return result_pb
  195. def reconstruct(self, result_pb):
  196. self.role = result_pb.role
  197. self.party_id = result_pb.party_id
  198. binning_result = dict(result_pb.binning_result)
  199. for col_name, col_bin_result in binning_result.items():
  200. col_bin_obj = BinColResults()
  201. col_bin_obj.reconstruct(col_bin_result)
  202. self.all_cols_results[col_name] = col_bin_obj
  203. return self
  204. def update_anonymous(self, anonymous_header_dict):
  205. all_cols_results = dict()
  206. for col_name, col_bin_result in self.all_cols_results.items():
  207. updated_col_name = anonymous_header_dict[col_name]
  208. all_cols_results[updated_col_name] = col_bin_result
  209. self.all_cols_results = all_cols_results
  210. return self
  211. class MultiClassBinResult(BinResults):
  212. def __init__(self, labels):
  213. super().__init__()
  214. self.labels = labels
  215. if len(self.labels) == 2:
  216. self.is_multi_class = False
  217. self.bin_results = [BinResults()]
  218. else:
  219. self.is_multi_class = True
  220. self.bin_results = [BinResults() for _ in range(len(self.labels))]
  221. def set_role_party(self, role, party_id):
  222. self.role = role
  223. self.party_id = party_id
  224. for br in self.bin_results:
  225. br.set_role_party(role, party_id)
  226. def put_col_results(self, col_name, col_results: BinColResults, label_idx=0):
  227. self.bin_results[label_idx].put_col_results(col_name, col_results)
  228. def summary(self, split_points=None):
  229. if not self.is_multi_class:
  230. return {"result": self.bin_results[0].summary(split_points)}
  231. return {label: self.bin_results[label_idx].summary(split_points) for
  232. label_idx, label in enumerate(self.labels)}
  233. def put_col_split_points(self, col_name, split_points, label_idx=None):
  234. if label_idx is None:
  235. for br in self.bin_results:
  236. br.put_col_split_points(col_name, split_points)
  237. else:
  238. self.bin_results[label_idx].put_col_split_points(col_name, split_points)
  239. def put_optimal_metric_array(self, col_name, metric_array, label_idx=None):
  240. if label_idx is None:
  241. for br in self.bin_results:
  242. br.put_optimal_metric_array(col_name, metric_array)
  243. else:
  244. self.bin_results[label_idx].put_optimal_metric_array(col_name, metric_array)
  245. def generated_pb_list(self, split_points=None):
  246. res = []
  247. for br in self.bin_results:
  248. res.append(br.generated_pb(split_points))
  249. return res
  250. @staticmethod
  251. def reconstruct(result_pb, labels=None):
  252. if not isinstance(result_pb, list):
  253. result_pb = [result_pb]
  254. if labels is None:
  255. if len(result_pb) <= 1:
  256. labels = [0, 1]
  257. else:
  258. labels = list(range(len(result_pb)))
  259. result = MultiClassBinResult(labels)
  260. for idx, pb in enumerate(result_pb):
  261. result.bin_results[idx].reconstruct(pb)
  262. return result
  263. def update_anonymous(self, anonymous_header_dict):
  264. for idx in range(len(self.bin_results)):
  265. self.bin_results[idx].update_anonymous(anonymous_header_dict)
  266. @property
  267. def all_split_points(self):
  268. return self.bin_results[0].all_split_points