evaluation.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import defaultdict
import math

from federatedml.util import LOGGER
from federatedml.model_base import Metric, MetricMeta
from federatedml.param import EvaluateParam
from federatedml.util import consts
from federatedml.model_base import ModelBase
from federatedml.evaluation.metric_interface import MetricInterface
from federatedml.statistic.data_overview import predict_detail_str_to_dict
import numpy as np


class Evaluation(ModelBase):

    def __init__(self):
        super().__init__()
        self.model_param = EvaluateParam()
        self.eval_results = defaultdict(list)

        self.save_single_value_metric_list = [consts.AUC,
                                              consts.EXPLAINED_VARIANCE,
                                              consts.MEAN_ABSOLUTE_ERROR,
                                              consts.MEAN_SQUARED_ERROR,
                                              consts.MEAN_SQUARED_LOG_ERROR,
                                              consts.MEDIAN_ABSOLUTE_ERROR,
                                              consts.R2_SCORE,
                                              consts.ROOT_MEAN_SQUARED_ERROR,
                                              consts.JACCARD_SIMILARITY_SCORE,
                                              consts.ADJUSTED_RAND_SCORE,
                                              consts.FOWLKES_MALLOWS_SCORE,
                                              consts.DAVIES_BOULDIN_INDEX
                                              ]
        self.special_metric_list = [consts.PSI]
        self.clustering_intra_metric_list = [
            consts.DAVIES_BOULDIN_INDEX, consts.DISTANCE_MEASURE]

        self.metrics = None
        self.round_num = 6
        self.eval_type = None

        # where to call metric computations
        self.metric_interface: MetricInterface = None

        self.psi_train_scores, self.psi_validate_scores = None, None
        self.psi_train_labels, self.psi_validate_labels = None, None

        # multi unfold setting
        self.need_unfold_multi_result = False

        # summaries
        self.metric_summaries = {}

    def _init_model(self, model):
        self.model_param = model
        self.eval_type = self.model_param.eval_type
        self.pos_label = self.model_param.pos_label
        self.need_unfold_multi_result = self.model_param.unfold_multi_result
        self.metrics = model.metrics
        self.metric_interface = MetricInterface(
            pos_label=self.pos_label, eval_type=self.eval_type)

    def _run_data(self, data_sets=None, stage=None):
        if not self.need_run:
            return

        data = {}
        for data_key in data_sets:
            if data_sets[data_key].get("data", None):
                data[data_key] = data_sets[data_key]["data"]

        if stage == "fit":
            self.data_output = self.fit(data)
        else:
            LOGGER.warning("Evaluation has no transform stage, return directly")

    def split_data_with_type(self, data: list) -> dict:
        split_result = defaultdict(list)
        for value in data:
            mode = value[1][-1]
            split_result[mode].append(value)
        return split_result
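
    # A minimal sketch of what split_data_with_type consumes and returns, assuming
    # the predict-result layout used elsewhere in this module
    # ([label, predict_label, predict_score, predict_detail, data_type]);
    # sample values below are made up:
    #
    #   data = [("id_0", [1, 1, 0.9, "{...}", "train"]),
    #           ("id_1", [0, 1, 0.6, "{...}", "validate"])]
    #   # -> {"train": [("id_0", [...])], "validate": [("id_1", [...])]}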

    def _classification_and_regression_extract(self, data):
        """
        extract labels and predict results from data in classification/regression format
        """
        labels = []
        pred_scores = []
        pred_labels = []
        for d in data:
            labels.append(d[1][0])
            pred_labels.append(d[1][1])
            pred_scores.append(d[1][2])

        if self.eval_type == consts.BINARY or self.eval_type == consts.REGRESSION:
            if self.pos_label and self.eval_type == consts.BINARY:
                labels_arr = np.array(labels)
                # compute the positive mask before overwriting, otherwise labels
                # already rewritten to 1 would also be caught by the negative check
                pos_mask = labels_arr == self.pos_label
                labels_arr[pos_mask] = 1
                labels_arr[~pos_mask] = 0
                labels = list(labels_arr)
            pred_results = pred_scores
        else:
            pred_results = pred_labels

        return labels, pred_results
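
    # Hedged illustration of the binary remapping above (made-up labels, assuming
    # pos_label=2): labels [2, 1, 2] are rewritten to [1, 0, 1] and the positive-class
    # scores are kept as pred_results.
    #
    #   labels_arr = np.array([2, 1, 2])
    #   pos_mask = labels_arr == 2
    #   labels_arr[pos_mask] = 1
    #   labels_arr[~pos_mask] = 0      # -> array([1, 0, 1])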

    def _clustering_extract(self, data):
        """
        extract data according to data format
        """
        true_cluster_index, predicted_cluster_index = [], []
        intra_cluster_data, inter_cluster_dist = {
            'avg_dist': [], 'max_radius': []}, []

        run_intra_metrics = False  # run intra metrics or outer metrics?
        if len(data[0][1]) == 3:
            # [int, int] -> [true label, predicted label] -> outer metric
            # [int, np.array] -> [predicted label, distance] -> need no metric computation
            if not (isinstance(data[0][1][0], int) and isinstance(data[0][1][1], int)):
                return None, None, run_intra_metrics
        if len(data[0][1]) == 5:  # the input format is for intra metrics
            run_intra_metrics = True

        cluster_index_list = []
        for d in data:
            if run_intra_metrics:
                cluster_index_list.append(d[0])
                intra_cluster_data['avg_dist'].append(d[1][1])
                intra_cluster_data['max_radius'].append(d[1][2])
                if len(inter_cluster_dist) == 0:
                    inter_cluster_dist += d[1][3]
            else:
                true_cluster_index.append(d[1][0])
                predicted_cluster_index.append(d[1][1])

        # if cluster related data exists, sort by cluster index
        if len(cluster_index_list) != 0:
            to_sort = list(zip(cluster_index_list,
                               intra_cluster_data['avg_dist'],
                               intra_cluster_data['max_radius']))
            sort_rs = sorted(to_sort, key=lambda x: x[0])  # sort by cluster index
            intra_cluster_data['avg_dist'] = [i[1] for i in sort_rs]
            intra_cluster_data['max_radius'] = [i[2] for i in sort_rs]

        return (true_cluster_index, predicted_cluster_index, run_intra_metrics) if not run_intra_metrics \
            else (intra_cluster_data, inter_cluster_dist, run_intra_metrics)
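
    # Record layouts accepted by _clustering_extract, as inferred from the parsing
    # above (values are illustrative):
    #
    #   outer metrics (len == 3): ("id_0", [true_label, predicted_label, "train"])
    #   intra metrics (len == 5): (cluster_idx, [cluster_idx, avg_dist, max_radius,
    #                                            inter_cluster_dists, "train"])
    #
    # A length-3 record whose second field is not an int carries [predicted_label,
    # distance] and is skipped, since there is nothing to evaluate.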

    def _evaluate_classification_and_regression_metrics(self, mode, data):
        labels, pred_results = self._classification_and_regression_extract(data)
        eval_result = defaultdict(list)
        for eval_metric in self.metrics:
            if eval_metric not in self.special_metric_list:
                res = getattr(self.metric_interface, eval_metric)(labels, pred_results)
                if res is not None:
                    try:
                        if math.isinf(res):
                            res = float(-9999999)
                            LOGGER.info("res is inf, set to {}".format(res))
                    except BaseException:
                        pass
                    eval_result[eval_metric].append(mode)
                    eval_result[eval_metric].append(res)
            elif eval_metric == consts.PSI:
                if mode == 'train':
                    self.psi_train_scores = pred_results
                    self.psi_train_labels = labels
                elif mode == 'validate':
                    self.psi_validate_scores = pred_results
                    self.psi_validate_labels = labels

                if self.psi_train_scores is not None and self.psi_validate_scores is not None:
                    res = self.metric_interface.psi(
                        self.psi_train_scores,
                        self.psi_validate_scores,
                        self.psi_train_labels,
                        self.psi_validate_labels)
                    eval_result[eval_metric].append(mode)
                    eval_result[eval_metric].append(res)
                    # delete saved scores after computing a psi pair
                    self.psi_train_scores, self.psi_validate_scores = None, None

        return eval_result
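
    # Metric dispatch above is by name: getattr(self.metric_interface, eval_metric)
    # resolves e.g. "auc" to the corresponding MetricInterface method and calls it
    # with (labels, pred_results). PSI is the only special case: it is computed once
    # per train/validate pair, after both passes have cached their scores, and the
    # cached scores are cleared afterwards.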

    def _evaluate_clustering_metrics(self, mode, data):
        eval_result = defaultdict(list)
        rs0, rs1, run_intra_metric = self._clustering_extract(data)

        # skip evaluation if this input format carries nothing to evaluate
        if rs0 is None and rs1 is None:
            LOGGER.debug('skip computing, this clustering format is not for metric computation')
            return eval_result

        if not run_intra_metric:
            no_label = set(rs0) == {None}
            if no_label:
                LOGGER.debug('no label found in clustering result, skip metric computation')
                return eval_result

        for eval_metric in self.metrics:
            # the input format has to match the requested metric:
            # intra metrics need intra-cluster data, outer metrics need label pairs
            metric_is_intra = eval_metric in self.clustering_intra_metric_list
            if metric_is_intra != run_intra_metric:
                LOGGER.warning(
                    'input data format does not match current clustering metric: {}'.format(eval_metric))
                continue

            LOGGER.debug('clustering_metrics is {}'.format(eval_metric))

            if run_intra_metric:
                if eval_metric == consts.DISTANCE_MEASURE:
                    res = getattr(self.metric_interface, eval_metric)(
                        rs0['avg_dist'], rs1, rs0['max_radius'])
                else:
                    res = getattr(self.metric_interface, eval_metric)(rs0['avg_dist'], rs1)
            else:
                res = getattr(self.metric_interface, eval_metric)(rs0, rs1)

            eval_result[eval_metric].append(mode)
            eval_result[eval_metric].append(res)

        return eval_result

    @staticmethod
    def _check_clustering_input(data):
        # in the current version, one evaluation component can only serve
        # one kmeans component
        input_num = len(data.items())
        if input_num > 1:
            raise ValueError(
                'multiple input detected, '
                'one evaluation component is only available '
                'for one clustering (kmeans) component in current version')

    @staticmethod
    def _unfold_multi_result(score_list):
        """
        one-vs-rest transformation: unfold a multi-classification result into
        several binary-classification results
        """
        binary_result = {}
        for key, multi_result in score_list:
            true_label = multi_result[0]
            predicted_label = multi_result[1]
            multi_score = predict_detail_str_to_dict(multi_result[3])
            data_type = multi_result[-1]

            # convert to the binary predict-result format
            for multi_label in multi_score:
                bin_label = 1 if str(multi_label) == str(true_label) else 0
                bin_predicted_label = 1 if str(multi_label) == str(predicted_label) else 0
                bin_score = multi_score[multi_label]
                neg_bin_score = 1 - bin_score
                result_list = [bin_label, bin_predicted_label, bin_score,
                               {1: bin_score, 0: neg_bin_score}, data_type]
                if multi_label not in binary_result:
                    binary_result[multi_label] = []
                binary_result[multi_label].append((key, result_list))

        return binary_result
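
    # A hedged sketch of the unfold step (values are made up; it assumes
    # predict_detail_str_to_dict parses the detail string into a {label: score} dict):
    #
    #   score_list = [("id_0", [2, 2, 0.7, "{'0': 0.1, '1': 0.2, '2': 0.7}", "train"])]
    #   Evaluation._unfold_multi_result(score_list)
    #   # -> {'0': [("id_0", [0, 0, 0.1, {1: 0.1, 0: 0.9}, "train"])],
    #   #     '1': [("id_0", [0, 0, 0.2, {1: 0.2, 0: 0.8}, "train"])],
    #   #     '2': [("id_0", [1, 1, 0.7, {1: 0.7, 0: 0.3}, "train"])]}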

    def evaluate_metrics(self, mode: str, data: list) -> dict:
        eval_result = None
        if self.eval_type != consts.CLUSTERING:
            eval_result = self._evaluate_classification_and_regression_metrics(mode, data)
        elif self.eval_type == consts.CLUSTERING:
            LOGGER.debug('running clustering')
            eval_result = self._evaluate_clustering_metrics(mode, data)
        return eval_result

    def obtain_data(self, data_list):
        return data_list

    def check_data(self, data):
        if len(data) <= 0:
            return

        if self.eval_type == consts.CLUSTERING:
            self._check_clustering_input(data)
        else:
            for key, eval_data in data.items():
                if eval_data is None:
                    continue
                sample = eval_data.take(1)[0]
                # expected features: label, predict_type, predict_score, predict_detail, type
                if not isinstance(sample[1].features, list) or len(sample[1].features) != 5:
                    raise ValueError(
                        'length of table header mismatch, expected length is 5, got: {}, '
                        'please check the input of the Evaluation Module; results of '
                        'cross validation are not supported.'.format(sample))

    def fit(self, data, return_result=False):
        self.check_data(data)
        LOGGER.debug(f'running eval, data: {data}')
        self.eval_results.clear()

        for (key, eval_data) in data.items():
            if eval_data is None:
                LOGGER.debug('data with key {} is None, skip metric computation'.format(key))
                continue

            collected_data = list(eval_data.collect())
            if len(collected_data) == 0:
                continue

            eval_data_local = []
            for k, v in collected_data:
                eval_data_local.append((k, v.features))

            split_data_with_label = self.split_data_with_type(eval_data_local)
            for mode, mode_data in split_data_with_label.items():
                eval_result = self.evaluate_metrics(mode, mode_data)
                self.eval_results[key].append(eval_result)

            if self.need_unfold_multi_result and self.eval_type == consts.MULTY:
                unfold_binary_eval_result = defaultdict(list)

                # switch work mode to binary evaluation
                self.eval_type = consts.BINARY
                self.metric_interface.eval_type = consts.ONE_VS_REST
                back_up_metric = self.metrics
                self.metrics = [consts.AUC, consts.KS]

                for mode, mode_data in split_data_with_label.items():
                    unfold_multi_data = self._unfold_multi_result(eval_data_local)
                    for multi_label, marginal_bin_result in unfold_multi_data.items():
                        eval_result = self.evaluate_metrics(mode, marginal_bin_result)
                        new_key = key + '_class_{}'.format(multi_label)
                        unfold_binary_eval_result[new_key].append(eval_result)

                self.callback_ovr_metric_data(unfold_binary_eval_result)

                # recover work mode
                self.eval_type = consts.MULTY
                self.metric_interface.eval_type = consts.MULTY
                self.metrics = back_up_metric

        return self.callback_metric_data(
            self.eval_results, return_single_val_metrics=return_result)

    def __save_single_value(self, result, metric_name, metric_namespace, eval_name):
        metric_type = 'EVALUATION_SUMMARY'
        if eval_name in consts.ALL_CLUSTER_METRICS:
            metric_type = 'CLUSTERING_EVALUATION_SUMMARY'

        self.tracker.log_metric_data(
            metric_namespace, metric_name,
            [Metric(eval_name, np.round(result, self.round_num))])
        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type=metric_type))

    def __save_curve_data(self, x_axis_list, y_axis_list, metric_name, metric_namespace):
        points = []
        for i, value in enumerate(x_axis_list):
            if isinstance(value, float):
                value = np.round(value, self.round_num)
            points.append((value, np.round(y_axis_list[i], self.round_num)))
        points.sort(key=lambda x: x[0])

        metric_points = [Metric(point[0], point[1]) for point in points]
        self.tracker.log_metric_data(metric_namespace, metric_name, metric_points)

    def __save_curve_meta(self, metric_name, metric_namespace, metric_type,
                          unit_name=None, ordinate_name=None, curve_name=None,
                          best=None, pair_type=None, thresholds=None):
        extra_metas = {}
        metric_type = "_".join([metric_type, "EVALUATION"])

        key_list = ["unit_name", "ordinate_name", "curve_name",
                    "best", "pair_type", "thresholds"]
        for key in key_list:
            value = locals()[key]
            if value:
                if key == "thresholds":
                    value = np.round(value, self.round_num).tolist()
                extra_metas[key] = value

        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type=metric_type, extra_metas=extra_metas))

    @staticmethod
    def __multi_class_label_padding(metrics, label_indices):
        # in case some labels don't appear when running homo-multi-class algo
        label_num = np.max(label_indices) + 1
        index_result_mapping = dict(zip(label_indices, metrics))
        new_metrics, new_label_indices = [], []
        for i in range(label_num):
            if i in index_result_mapping:
                new_metrics.append(index_result_mapping[i])
            else:
                new_metrics.append(0.0)
            new_label_indices.append(i)

        return new_metrics, new_label_indices
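
    # Illustrative padding behavior (made-up per-class scores): metrics observed only
    # for labels 0 and 2 are padded with 0.0 for the missing label 1.
    #
    #   Evaluation._Evaluation__multi_class_label_padding([0.9, 0.8], [0, 2])
    #   # -> ([0.9, 0.0, 0.8], [0, 1, 2])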

    @staticmethod
    def __filt_override_unit_ordinate_coordinate(x_sets, y_sets):
        max_y_dict = {}
        for idx, x_value in enumerate(x_sets):
            if x_value not in max_y_dict:
                max_y_dict[x_value] = {"max_y": y_sets[idx], "idx": idx}
            else:
                max_y = max_y_dict[x_value]["max_y"]
                if max_y < y_sets[idx]:
                    max_y_dict[x_value] = {"max_y": y_sets[idx], "idx": idx}

        x = []
        y = []
        idx_list = []
        for key, value in max_y_dict.items():
            x.append(key)
            y.append(value["max_y"])
            idx_list.append(value["idx"])

        return x, y, idx_list
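
    # For duplicated x values only the largest y (and its index) is kept, e.g.
    # (made-up points):
    #
    #   x_sets = [0.1, 0.1, 0.2]
    #   y_sets = [0.3, 0.5, 0.4]
    #   # -> x = [0.1, 0.2], y = [0.5, 0.4], idx_list = [1, 2]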

    def __process_single_value_data(self, metric, metric_res):
        single_val_metric = None

        if metric in self.save_single_value_metric_list or \
                (metric == consts.ACCURACY and self.eval_type == consts.MULTY):
            single_val_metric = metric_res[1]
        elif metric == consts.KS:
            best_ks, fpr, tpr, thresholds, cuts = metric_res[1]
            single_val_metric = best_ks
        elif metric in [consts.RECALL, consts.PRECISION] and self.eval_type == consts.MULTY:
            pos_score = metric_res[1][0]
            single_val_metric = float(np.array(pos_score).mean())

        return single_val_metric

    @staticmethod
    def __filter_duplicate_roc_data_point(fpr, tpr, thresholds):
        data_point_set = set()
        new_fpr, new_tpr, new_threshold = [], [], []
        for fpr_, tpr_, thres in zip(fpr, tpr, thresholds):
            if (fpr_, tpr_, thres) not in data_point_set:
                data_point_set.add((fpr_, tpr_, thres))
                new_fpr.append(fpr_)
                new_tpr.append(tpr_)
                new_threshold.append(thres)

        return new_fpr, new_tpr, new_threshold
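
    # Exact duplicate (fpr, tpr, threshold) triples are dropped while the original
    # order is preserved, e.g. (made-up points):
    #
    #   fpr, tpr, thr = [0.0, 0.0, 0.5], [0.2, 0.2, 0.9], [0.9, 0.9, 0.3]
    #   # -> ([0.0, 0.5], [0.2, 0.9], [0.9, 0.3])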

    def __save_roc_curve(self, data_name, metric_name, metric_namespace, metric_res):
        fpr, tpr, thresholds, _ = metric_res
        fpr, tpr, thresholds = self.__filter_duplicate_roc_data_point(fpr, tpr, thresholds)

        # set roc edge value
        fpr.append(1.0)
        tpr.append(1.0)
        thresholds.append(1.0)

        self.__save_curve_data(fpr, tpr, metric_name, metric_namespace)
        self.__save_curve_meta(
            metric_name=metric_name,
            metric_namespace=metric_namespace,
            metric_type="ROC",
            unit_name="fpr",
            ordinate_name="tpr",
            curve_name=data_name,
            thresholds=thresholds)

    def __save_ks_curve(self, metric, metric_res, metric_name, metric_namespace, data_name):
        best_ks, fpr, tpr, thresholds, cuts = metric_res[1]

        for curve_name, curve_data in zip(["fpr", "tpr"], [fpr, tpr]):
            metric_name_fpr = '_'.join([metric_name, curve_name])
            curve_name_fpr = "_".join([data_name, curve_name])
            self.__save_curve_data(cuts, curve_data, metric_name_fpr, metric_namespace)
            self.__save_curve_meta(
                metric_name=metric_name_fpr,
                metric_namespace=metric_namespace,
                metric_type=metric.upper(),
                unit_name="",
                curve_name=curve_name_fpr,
                pair_type=data_name,
                thresholds=thresholds)

    def __save_lift_gain_curve(self, metric, metric_res, metric_name, metric_namespace, data_name):
        score, cuts, thresholds = metric_res[1]

        score = [float(s[1]) for s in score]
        cuts = [float(c[1]) for c in cuts]
        cuts, score, idx_list = self.__filt_override_unit_ordinate_coordinate(cuts, score)
        thresholds = [thresholds[idx] for idx in idx_list]

        score.append(1.0)
        cuts.append(1.0)
        thresholds.append(0.0)

        self.__save_curve_data(cuts, score, metric_name, metric_namespace)
        self.__save_curve_meta(
            metric_name=metric_name,
            metric_namespace=metric_namespace,
            metric_type=metric.upper(),
            unit_name="",
            curve_name=data_name,
            thresholds=thresholds)

    def __save_accuracy_curve(self, metric, metric_res, metric_name, metric_namespace, data_name):
        if self.eval_type == consts.MULTY:
            return

        score, cuts, thresholds = metric_res[1]

        self.__save_curve_data(cuts, score, metric_name, metric_namespace)
        self.__save_curve_meta(
            metric_name=metric_name,
            metric_namespace=metric_namespace,
            metric_type=metric.upper(),
            unit_name="",
            curve_name=data_name,
            thresholds=thresholds)

    def __save_pr_curve(self, precision_and_recall, data_name):
        precision_res = precision_and_recall[consts.PRECISION]
        recall_res = precision_and_recall[consts.RECALL]

        if precision_res[0] != recall_res[0]:
            LOGGER.warning(
                "precision mode:{} is not equal to recall mode:{}".format(
                    precision_res[0], recall_res[0]))
            return

        metric_namespace = precision_res[0]
        metric_name_precision = '_'.join([data_name, "precision"])
        metric_name_recall = '_'.join([data_name, "recall"])

        pos_precision_score = precision_res[1][0]
        precision_cuts = precision_res[1][1]
        if len(precision_res[1]) >= 3:
            precision_thresholds = precision_res[1][2]
        else:
            precision_thresholds = None

        pos_recall_score = recall_res[1][0]
        recall_cuts = recall_res[1][1]
        if len(recall_res[1]) >= 3:
            recall_thresholds = recall_res[1][2]
        else:
            recall_thresholds = None

        precision_curve_name = data_name
        recall_curve_name = data_name

        if self.eval_type == consts.BINARY:
            pos_precision_score = [score[1] for score in pos_precision_score]
            pos_recall_score = [score[1] for score in pos_recall_score]

            pos_recall_score, pos_precision_score, idx_list = self.__filt_override_unit_ordinate_coordinate(
                pos_recall_score, pos_precision_score)

            precision_cuts = [precision_cuts[idx] for idx in idx_list]
            recall_cuts = [recall_cuts[idx] for idx in idx_list]

            edge_idx = idx_list[-1]
            if edge_idx == len(precision_thresholds) - 1:
                idx_list = idx_list[:-1]
            precision_thresholds = [precision_thresholds[idx] for idx in idx_list]
            recall_thresholds = [recall_thresholds[idx] for idx in idx_list]

        elif self.eval_type == consts.MULTY:
            pos_recall_score, recall_cuts = self.__multi_class_label_padding(
                pos_recall_score, recall_cuts)
            pos_precision_score, precision_cuts = self.__multi_class_label_padding(
                pos_precision_score, precision_cuts)

        self.__save_curve_data(precision_cuts, pos_precision_score,
                               metric_name_precision, metric_namespace)
        self.__save_curve_meta(metric_name_precision,
                               metric_namespace,
                               "_".join([consts.PRECISION.upper(), self.eval_type.upper()]),
                               unit_name="",
                               ordinate_name="Precision",
                               curve_name=precision_curve_name,
                               pair_type=data_name,
                               thresholds=precision_thresholds)

        self.__save_curve_data(recall_cuts, pos_recall_score,
                               metric_name_recall, metric_namespace)
        self.__save_curve_meta(metric_name_recall,
                               metric_namespace,
                               "_".join([consts.RECALL.upper(), self.eval_type.upper()]),
                               unit_name="",
                               ordinate_name="Recall",
                               curve_name=recall_curve_name,
                               pair_type=data_name,
                               thresholds=recall_thresholds)

    def __save_confusion_mat_table(self, metric, confusion_mat, thresholds,
                                   metric_name, metric_namespace):
        extra_metas = {
            'tp': list(confusion_mat['tp']),
            'tn': list(confusion_mat['tn']),
            'fp': list(confusion_mat['fp']),
            'fn': list(confusion_mat['fn']),
            'thresholds': list(np.round(thresholds, self.round_num))}

        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))

    def __save_f1_score_table(self, metric, f1_scores, thresholds,
                              metric_name, metric_namespace):
        extra_metas = {
            'f1_scores': list(np.round(f1_scores, self.round_num)),
            'thresholds': list(np.round(thresholds, self.round_num))}

        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))

    def __save_psi_table(self, metric, metric_res, metric_name, metric_namespace):
        psi_scores, total_psi, expected_interval, expected_percentage, actual_interval, \
            actual_percentage, train_pos_perc, validate_pos_perc, intervals = metric_res[1]

        extra_metas = {
            'psi_scores': list(np.round(psi_scores, self.round_num)),
            'total_psi': round(total_psi, self.round_num),
            'expected_interval': list(expected_interval),
            'expected_percentage': list(expected_percentage),
            'actual_interval': list(actual_interval),
            'actual_percentage': list(actual_percentage),
            'intervals': list(intervals),
            'train_pos_perc': train_pos_perc,
            'validate_pos_perc': validate_pos_perc}

        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))

    def __save_pr_table(self, metric, metric_res, metric_name, metric_namespace):
        p_scores, r_scores, score_threshold = metric_res

        extra_metas = {'p_scores': list(map(list, np.round(p_scores, self.round_num))),
                       'r_scores': list(map(list, np.round(r_scores, self.round_num))),
                       'thresholds': list(np.round(score_threshold, self.round_num))}

        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))

    def __save_contingency_matrix(self, metric, metric_res, metric_name, metric_namespace):
        result_array, unique_predicted_label, unique_true_label = metric_res
        true_labels = list(map(int, unique_true_label))
        predicted_label = list(map(int, unique_predicted_label))
        result_table = []
        for l_ in result_array:
            result_table.append(list(map(int, l_)))

        extra_metas = {
            'true_labels': true_labels,
            'predicted_labels': predicted_label,
            'result_table': result_table}

        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))

    def __save_distance_measure(self, metric, metric_res: dict, metric_name, metric_namespace):
        extra_metas = {}
        cluster_index = [k for k in metric_res.keys()]
        radius, nearest_idx = [], []
        for k in metric_res:
            radius.append(metric_res[k][0])
            nearest_idx.append(metric_res[k][1])

        extra_metas['cluster_index'] = cluster_index
        extra_metas['radius'] = radius
        extra_metas['nearest_idx'] = nearest_idx

        self.tracker.set_metric_meta(
            metric_namespace, metric_name,
            MetricMeta(name=metric_name, metric_type=metric.upper(), extra_metas=extra_metas))

    def __update_summary(self, data_type, namespace, metric, metric_val):
        if data_type not in self.metric_summaries:
            self.metric_summaries[data_type] = {}
        if namespace not in self.metric_summaries[data_type]:
            self.metric_summaries[data_type][namespace] = {}
        self.metric_summaries[data_type][namespace][metric] = metric_val
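
    # The resulting summary is nested as {data_key: {namespace: {metric: value}}},
    # e.g. (made-up values):
    #
    #   {"predict_data": {"train": {"auc": 0.92, "ks": 0.61},
    #                     "validate": {"auc": 0.88, "ks": 0.55}}}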

    def __save_summary(self):
        LOGGER.info('eval summary is {}'.format(self.metric_summaries))
        self.set_summary(self.metric_summaries)

    def callback_ovr_metric_data(self, eval_results):

        for model_name, eval_rs in eval_results.items():

            train_callback_meta = defaultdict(dict)
            validate_callback_meta = defaultdict(dict)
            split_list = model_name.split('_')
            label = split_list[-1]
            # remove the trailing '_class_{label_index}' suffix to recover the
            # original model name
            origin_model_name = '_'.join(split_list[:-2])

            for rs_dict in eval_rs:
                for metric_name, metric_rs in rs_dict.items():
                    if metric_name == consts.KS:
                        # keep the ks value only, curve data is not needed
                        metric_rs = [metric_rs[0], metric_rs[1][0]]
                    metric_namespace = metric_rs[0]
                    if metric_namespace == 'train':
                        callback_meta = train_callback_meta
                    else:
                        callback_meta = validate_callback_meta
                    callback_meta[label][metric_name] = metric_rs[1]

            self.tracker.set_metric_meta(
                "train", model_name + '_' + 'ovr',
                MetricMeta(name=origin_model_name, metric_type='ovr',
                           extra_metas=train_callback_meta))
            self.tracker.set_metric_meta(
                "validate", model_name + '_' + 'ovr',
                MetricMeta(name=origin_model_name, metric_type='ovr',
                           extra_metas=validate_callback_meta))

            LOGGER.debug('callback data {} {}'.format(train_callback_meta, validate_callback_meta))

    def callback_metric_data(self, eval_results, return_single_val_metrics=False):

        # collect single val metric for validation strategy
        validate_metric = {}
        train_metric = {}
        collect_dict = {}

        for (data_type, eval_res_list) in eval_results.items():

            precision_recall = {}

            for eval_res in eval_res_list:
                for (metric, metric_res) in eval_res.items():

                    metric_namespace = metric_res[0]

                    if metric_namespace == 'validate':
                        collect_dict = validate_metric
                    elif metric_namespace == 'train':
                        collect_dict = train_metric

                    metric_name = '_'.join([data_type, metric])

                    single_val_metric = self.__process_single_value_data(metric, metric_res)

                    if single_val_metric is not None:
                        self.__save_single_value(
                            single_val_metric,
                            metric_name=data_type,
                            metric_namespace=metric_namespace,
                            eval_name=metric)
                        collect_dict[metric] = single_val_metric
                        # update pipeline summary
                        self.__update_summary(data_type, metric_namespace, metric, single_val_metric)

                    if metric == consts.KS:
                        self.__save_ks_curve(metric, metric_res, metric_name, metric_namespace, data_type)

                    elif metric == consts.ROC:
                        self.__save_roc_curve(data_type, metric_name, metric_namespace, metric_res[1])

                    elif metric == consts.ACCURACY:
                        self.__save_accuracy_curve(metric, metric_res, metric_name, metric_namespace, data_type)

                    elif metric in [consts.GAIN, consts.LIFT]:
                        self.__save_lift_gain_curve(metric, metric_res, metric_name, metric_namespace, data_type)

                    elif metric in [consts.PRECISION, consts.RECALL]:
                        precision_recall[metric] = metric_res
                        if len(precision_recall) < 2:
                            continue
                        self.__save_pr_curve(precision_recall, data_type)
                        precision_recall = {}  # reset cached dict

                    elif metric == consts.PSI:
                        self.__save_psi_table(metric, metric_res, metric_name, metric_namespace)

                    elif metric == consts.CONFUSION_MAT:
                        confusion_mat, cuts, score_threshold = metric_res[1]
                        self.__save_confusion_mat_table(
                            metric, confusion_mat, score_threshold, metric_name, metric_namespace)

                    elif metric == consts.F1_SCORE:
                        f1_scores, cuts, score_threshold = metric_res[1]
                        self.__save_f1_score_table(
                            metric, f1_scores, score_threshold, metric_name, metric_namespace)

                    elif metric == consts.QUANTILE_PR:
                        self.__save_pr_table(metric, metric_res[1], metric_name, metric_namespace)

                    elif metric == consts.CONTINGENCY_MATRIX:
                        self.__save_contingency_matrix(metric, metric_res[1], metric_name, metric_namespace)

                    elif metric == consts.DISTANCE_MEASURE:
                        self.__save_distance_measure(metric, metric_res[1], metric_name, metric_namespace)

        self.__save_summary()

        if return_single_val_metrics:
            if len(validate_metric) != 0:
                LOGGER.debug("return validate metric")
                LOGGER.debug('validate metric is {}'.format(validate_metric))
                return validate_metric
            else:
                LOGGER.debug("validate metric is empty, return train metric")
                LOGGER.debug('train metric is {}'.format(train_metric))
                return train_metric
        else:
            return None

    @staticmethod
    def extract_data(data: dict):
        result = {}
        for k, v in data.items():
            result[".".join(k.split(".")[:1])] = v
        return result
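
    # extract_data keeps only the part of each key before the first dot, e.g.
    # (hypothetical keys): {"hetero_lr_0.predict": table} -> {"hetero_lr_0": table}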