# psi.py -- Population Stability Index (PSI) computation component.
  1. import functools
  2. import copy
  3. import numpy as np
  4. from federatedml.feature.binning.quantile_binning import QuantileBinning
  5. from federatedml.param.feature_binning_param import FeatureBinningParam
  6. from federatedml.util import consts
  7. from federatedml.feature.fate_element_type import NoneType
  8. from federatedml.feature.instance import Instance
  9. from federatedml.feature.sparse_vector import SparseVector
  10. from federatedml.model_base import ModelBase
  11. from federatedml.param.psi_param import PSIParam
  12. from federatedml.util import LOGGER
  13. from federatedml.protobuf.generated.psi_model_param_pb2 import PsiSummary, FeaturePsi
  14. from federatedml.protobuf.generated.psi_model_meta_pb2 import PSIMeta
  15. from federatedml.util import abnormal_detection
# Number of decimal places used when rounding split points / exported values.
ROUND_NUM = 6
  17. def map_partition_handle(iterable, feat_num=10, max_bin_num=20, is_sparse=False, missing_val=NoneType()):
  18. count_bin = np.zeros((feat_num, max_bin_num))
  19. row_idx = np.array([i for i in range(feat_num)])
  20. for k, v in iterable:
  21. # last bin is for missing value
  22. if is_sparse:
  23. feature_dict = v.features.sparse_vec
  24. arr = np.zeros(feat_num, dtype=np.int64) + max_bin_num - 1 # max_bin_num - 1 is the missing bin val
  25. arr[list(feature_dict.keys())] = list(feature_dict.values())
  26. else:
  27. arr = v.features
  28. arr[arr == missing_val] = max_bin_num - 1
  29. count_bin[row_idx, arr.astype(np.int64)] += 1
  30. return count_bin
  31. def map_partition_reduce(arr1, arr2):
  32. return arr1 + arr2
  33. def psi_computer(expect_counter_list, actual_counter_list, expect_sample_count, actual_sample_count):
  34. psi_rs = []
  35. for exp_counter, acu_counter in zip(expect_counter_list, actual_counter_list):
  36. feat_psi = {}
  37. for key in exp_counter:
  38. feat_psi[key] = psi_val(exp_counter[key] / expect_sample_count, acu_counter[key] / actual_sample_count)
  39. total_psi = 0
  40. for k in feat_psi:
  41. total_psi += feat_psi[k]
  42. feat_psi['total_psi'] = total_psi
  43. psi_rs.append(feat_psi)
  44. return psi_rs
  45. def psi_val(expected_perc, actual_perc):
  46. if expected_perc == 0:
  47. expected_perc = 1e-6
  48. if actual_perc == 0:
  49. actual_perc = 1e-6
  50. return (actual_perc - expected_perc) * np.log(actual_perc / expected_perc)
  51. def psi_val_arr(expected_arr, actual_arr, sample_num):
  52. expected_arr = expected_arr / sample_num
  53. actual_arr = actual_arr / sample_num
  54. expected_arr[expected_arr == 0] = 1e-6
  55. actual_arr[actual_arr == 0] = 1e-6
  56. psi_rs = (actual_arr - expected_arr) * np.log(actual_arr / expected_arr)
  57. return psi_rs
  58. def count_rs_to_dict(arrs):
  59. dicts = []
  60. for i in arrs:
  61. rs_dict = {}
  62. for k, v in enumerate(i):
  63. rs_dict[k] = v
  64. dicts.append(rs_dict)
  65. return dicts
  66. def np_nan_to_nonetype(inst):
  67. arr = inst.features
  68. index = np.isnan(arr)
  69. if index.any():
  70. inst = copy.deepcopy(inst)
  71. arr = arr.astype(object)
  72. arr[index] = NoneType()
  73. inst.features = arr
  74. return inst
class PSI(ModelBase):
    """Population Stability Index component.

    Bins the baseline ("expect") table with quantile binning, applies the same
    split points to the "actual" table, counts per-feature bin occupancy on
    both, and derives per-bin and total PSI scores which are exported as
    protobuf model param/meta.
    """

    def __init__(self):
        super(PSI, self).__init__()
        self.model_param = PSIParam()
        # maximum number of quantile bins (an extra bin is appended at count
        # time for missing values)
        self.max_bin_num = 20
        # feature name <-> column index mappings (filled in fit)
        self.tag_id_mapping = {}
        self.id_tag_mapping = {}
        # per-feature {bin: count} dicts for expect / actual tables
        self.count1, self.count2 = None, None
        self.actual_table, self.expect_table = None, None
        self.data_bin1, self.data_bin2 = None, None
        self.bin_split_points = None
        self.bin_sparse_points = None
        # list of per-feature psi dicts (incl. 'total_psi' key)
        self.psi_rs = None
        # {feature_name: total_psi}
        self.total_scores = None
        self.all_feature_list = None
        self.dense_missing_val = NoneType()
        self.binning_error = consts.DEFAULT_RELATIVE_ERROR
        # per-feature {bin: percentage} dicts for expect / actual tables
        self.interval_perc1 = None
        self.interval_perc2 = None
        # per-feature {bin: human-readable interval string}
        self.str_intervals = None
        self.binning_obj = None

    def _init_model(self, model: PSIParam):
        """Load parameters from the parsed PSIParam object."""
        self.max_bin_num = model.max_bin_num
        self.need_run = model.need_run
        self.dense_missing_val = NoneType() if model.dense_missing_val is None else model.dense_missing_val
        self.binning_error = model.binning_error

    @staticmethod
    def check_table_content(tb):
        """Raise unless ``tb`` is non-empty and holds Instance values; return True."""
        if not tb.count() > 0:
            raise ValueError('input table must contains at least 1 sample')
        first_ = tb.take(1)[0][1]
        if isinstance(first_, Instance):
            return True
        else:
            raise ValueError('unknown input format')

    @staticmethod
    def is_sparse(tb):
        """True when the table's feature vectors are SparseVector instances."""
        return isinstance(tb.take(1)[0][1].features, SparseVector)

    @staticmethod
    def check_duplicates(l_):
        """Return ``l_`` with duplicates removed, preserving first-seen order."""
        s = set(l_)
        recorded = set()
        new_l = []
        for i in l_:
            # NOTE(review): `i in s` is always true since s = set(l_); only
            # the `i not in recorded` check does the de-duplication here.
            if i in s and i not in recorded:
                new_l.append(i)
                recorded.add(i)
        return new_l

    @staticmethod
    def get_string_interval(data_split_points, id_tag_mapping, missing_bin_idx):
        """Build, per feature, a {bin_index: "lo<name<=hi"} display mapping.

        The first/last bins are open-ended ('-inf' / 'inf' sentinels suppress
        the corresponding bound in the string); ``missing_bin_idx`` maps to
        the literal 'missing'.
        """
        # generate string interval from bin_split_points
        feature_interval = []
        for feat_idx, interval in enumerate(data_split_points):
            idx2intervals = {}
            l_ = list(interval)
            l_[-1] = 'inf'
            l_.insert(0, '-inf')
            idx = 0
            # iterate consecutive (start, end) split-point pairs
            for s, e in zip(l_[:-1], l_[1:]):
                interval_str = str(id_tag_mapping[feat_idx])
                if s != '-inf':
                    interval_str = str(np.round(s, ROUND_NUM)) + "<" + interval_str
                if e != 'inf':
                    interval_str = interval_str + "<=" + str(np.round(e, ROUND_NUM))
                idx2intervals[idx] = interval_str
                idx += 1
            idx2intervals[missing_bin_idx] = 'missing'
            feature_interval.append(idx2intervals)
        return feature_interval

    @staticmethod
    def post_process_result(rs_dict, interval_dict,):
        """Align a {bin: value} dict with its {bin: interval-string} dict.

        Returns two parallel lists ordered by bin index: rounded values and
        their interval strings. Bins present in ``interval_dict`` drive the
        iteration, so ``rs_dict`` must cover the same keys.
        """
        rs_val_list, interval_list = [], []
        for key in sorted(interval_dict.keys()):
            corresponding_str_interval = interval_dict[key]
            val = rs_dict[key]
            rs_val_list.append(np.round(val, ROUND_NUM))
            interval_list.append(corresponding_str_interval)
        return rs_val_list, interval_list

    @staticmethod
    def count_dict_to_percentage(count_rs, sample_num):
        """In place, divide every bin count by ``sample_num``; returns the same list."""
        for c in count_rs:
            for k in c:
                c[k] = c[k] / sample_num
        return count_rs

    @staticmethod
    def convert_missing_val(table):
        """Map NaN features to NoneType() sentinels, preserving the table schema."""
        new_table = table.mapValues(np_nan_to_nonetype)
        new_table.schema = table.schema
        return new_table

    def fit(self, expect_table, actual_table):
        """Compute PSI of ``actual_table`` against baseline ``expect_table``.

        Both tables must share the same header set. Side effects: populates
        the counting/psi/interval attributes and sets the component summary.
        """
        LOGGER.info('start psi computing')
        header1 = expect_table.schema['header']
        header2 = actual_table.schema['header']
        if not set(header1) == set(header2):
            raise ValueError('table header must be the same while computing psi values')
        # baseline table should not contain empty columns
        abnormal_detection.empty_column_detection(expect_table)
        self.all_feature_list = header1
        # make sure no duplicate features
        self.all_feature_list = self.check_duplicates(self.all_feature_list)
        # kv bi-directional mapping
        self.tag_id_mapping = {v: k for k, v in enumerate(self.all_feature_list)}
        self.id_tag_mapping = {k: v for k, v in enumerate(self.all_feature_list)}
        if not self.is_sparse(expect_table):  # convert missing value: nan to NoneType
            expect_table = self.convert_missing_val(expect_table)
        if not self.is_sparse(actual_table):  # convert missing value: nan to NoneType
            actual_table = self.convert_missing_val(actual_table)
        if not (self.check_table_content(expect_table) and self.check_table_content(actual_table)):
            raise ValueError('contents of input table must be instances of class "Instance"')
        # quantile-bin the baseline table; the same split points are reused
        # for the actual table below so both share a bin layout
        param = FeatureBinningParam(method=consts.QUANTILE, bin_num=self.max_bin_num, local_only=True,
                                    error=self.binning_error)
        binning_obj = QuantileBinning(params=param, abnormal_list=[NoneType()], allow_duplicate=False)
        binning_obj.fit_split_points(expect_table)
        data_bin, bin_split_points, bin_sparse_points = binning_obj.convert_feature_to_bin(expect_table)
        LOGGER.debug('bin split points is {}, shape is {}'.format(bin_split_points, bin_split_points.shape))
        self.binning_obj = binning_obj
        self.data_bin1 = data_bin
        self.bin_split_points = bin_split_points
        self.bin_sparse_points = bin_sparse_points
        LOGGER.debug('expect table binning done')
        # count bin occupancy per partition, then sum the partition matrices
        count_func1 = functools.partial(map_partition_handle,
                                        feat_num=len(self.all_feature_list),
                                        max_bin_num=self.max_bin_num + 1,  # an additional bin for missing value
                                        missing_val=self.dense_missing_val,
                                        is_sparse=self.is_sparse(self.data_bin1))
        map_rs1 = self.data_bin1.applyPartitions(count_func1)
        count1 = count_rs_to_dict(map_rs1.reduce(map_partition_reduce))
        data_bin2, bin_split_points2, bin_sparse_points2 = binning_obj.convert_feature_to_bin(actual_table)
        self.data_bin2 = data_bin2
        LOGGER.debug('actual table binning done')
        count_func2 = functools.partial(map_partition_handle,
                                        feat_num=len(self.all_feature_list),
                                        max_bin_num=self.max_bin_num + 1,  # an additional bin for missing value
                                        missing_val=self.dense_missing_val,
                                        is_sparse=self.is_sparse(self.data_bin2))
        map_rs2 = self.data_bin2.applyPartitions(count_func2)
        count2 = count_rs_to_dict(map_rs2.reduce(map_partition_reduce))
        self.count1, self.count2 = count1, count2
        LOGGER.info('psi counting done')
        # compute psi from counting result
        psi_result = psi_computer(count1, count2, expect_table.count(), actual_table.count())
        self.psi_rs = psi_result
        # get total psi score of features
        total_scores = {}
        for idx, rs in enumerate(self.psi_rs):
            feat_name = self.id_tag_mapping[idx]
            total_scores[feat_name] = rs['total_psi']
        self.total_scores = total_scores
        # id-feature mapping convert, str interval computation
        self.str_intervals = self.get_string_interval(bin_split_points, self.id_tag_mapping,
                                                      missing_bin_idx=self.max_bin_num)
        # deep-copy so percentage conversion does not clobber the raw counts
        self.interval_perc1 = self.count_dict_to_percentage(copy.deepcopy(count1), expect_table.count())
        self.interval_perc2 = self.count_dict_to_percentage(copy.deepcopy(count2), actual_table.count())
        self.set_summary(self.generate_summary())
        LOGGER.info('psi computation done')

    def generate_summary(self):
        """Summary payload exposed to the framework: total psi per feature."""
        return {'psi_scores': self.total_scores}

    def export_model(self):
        """Serialize results to protobuf; returns None when need_run is False."""
        if not self.need_run:
            return None
        psi_summary = PsiSummary()
        psi_summary.total_score.update(self.total_scores)
        LOGGER.debug('psi total score is {}'.format(dict(psi_summary.total_score)))
        psi_summary.model_name = consts.PSI
        feat_psi_list = []
        # one FeaturePsi message per feature, aligned by column index
        for id_ in self.id_tag_mapping:
            feat_psi_summary = FeaturePsi()
            feat_name = self.id_tag_mapping[id_]
            feat_psi_summary.feature_name = feat_name
            interval_psi, str_intervals = self.post_process_result(self.psi_rs[id_], self.str_intervals[id_])
            interval_perc1, _ = self.post_process_result(self.interval_perc1[id_], self.str_intervals[id_])
            interval_perc2, _ = self.post_process_result(self.interval_perc2[id_], self.str_intervals[id_])
            feat_psi_summary.psi.extend(interval_psi)
            feat_psi_summary.expect_perc.extend(interval_perc1)
            feat_psi_summary.actual_perc.extend(interval_perc2)
            feat_psi_summary.interval.extend(str_intervals)
            feat_psi_list.append(feat_psi_summary)
        psi_summary.feature_psi.extend(feat_psi_list)
        LOGGER.debug('export model done')
        meta = PSIMeta()
        meta.max_bin_num = self.max_bin_num
        return {'PSIParam': psi_summary, 'PSIMeta': meta}