metric_interface.py
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
import numpy as np
import logging
from federatedml.util import consts
from federatedml.evaluation.metrics import classification_metric
from federatedml.evaluation.metrics import regression_metric
from federatedml.evaluation.metrics import clustering_metric
from functools import wraps

class MetricInterface(object):

    def __init__(self, pos_label: int, eval_type: str):
        self.pos_label = pos_label
        self.eval_type = eval_type

    def auc(self, labels, pred_scores):
        """
        Compute the AUC for binary classification.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            The AUC
        """
        if self.eval_type == consts.BINARY:
            return roc_auc_score(labels, pred_scores)
        elif self.eval_type == consts.ONE_VS_REST:
            try:
                score = roc_auc_score(labels, pred_scores)
            except BaseException:
                score = 0  # in case all true labels are 0 or 1 (single class)
                logging.warning("all true labels are 0/1 when running ovr AUC")
            return score
        else:
            logging.warning(
                "auc supports binary classification only; returning None")
            return None
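    # Illustrative usage (a sketch, not a verbatim example from this repo;
    # assumes consts.BINARY as defined in federatedml.util.consts):
    #   mi = MetricInterface(pos_label=1, eval_type=consts.BINARY)
    #   mi.auc([0, 1, 1, 0], [0.2, 0.8, 0.7, 0.4])  # -> 1.0, scores rank perfectly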
    @staticmethod
    def explained_variance(labels, pred_scores):
        """
        Compute the explained variance.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            The explained variance
        """
        return regression_metric.ExplainedVariance().compute(labels, pred_scores)

    @staticmethod
    def mean_absolute_error(labels, pred_scores):
        """
        Compute the mean absolute error.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            A non-negative floating point value
        """
        return regression_metric.MAE().compute(labels, pred_scores)

    @staticmethod
    def mean_squared_error(labels, pred_scores):
        """
        Compute the mean squared error.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            A non-negative floating point value
        """
        return regression_metric.MSE.compute(labels, pred_scores)

    @staticmethod
    def median_absolute_error(labels, pred_scores):
        """
        Compute the median absolute error.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            A non-negative floating point value
        """
        return regression_metric.MedianAbsoluteError().compute(labels, pred_scores)

    @staticmethod
    def r2_score(labels, pred_scores):
        """
        Compute the R^2 (coefficient of determination) score.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            The R^2 score
        """
        return regression_metric.R2Score().compute(labels, pred_scores)

    @staticmethod
    def root_mean_squared_error(labels, pred_scores):
        """
        Compute the root mean squared error.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            A non-negative floating point value
        """
        return regression_metric.RMSE.compute(labels, pred_scores)

    @staticmethod
    def __to_int_list(array: np.ndarray):
        return list(map(int, list(array)))

    @staticmethod
    def __filt_threshold(thresholds, step):
        # sample the descending-sorted thresholds at evenly spaced quantile
        # cuts, e.g. step=0.01 keeps roughly one threshold per percentile
        cuts = list(map(float, np.arange(0, 1, step)))
        size = len(list(thresholds))
        thresholds.sort(reverse=True)
        index_list = [int(size * cut) for cut in cuts]
        new_thresholds = [thresholds[idx] for idx in index_list]
        return new_thresholds, cuts
    def roc(self, labels, pred_scores):
        if self.eval_type == consts.BINARY:
            fpr, tpr, thresholds = roc_curve(
                np.array(labels), np.array(pred_scores), drop_intermediate=True)
            fpr, tpr, thresholds = list(map(float, fpr)), list(
                map(float, tpr)), list(map(float, thresholds))
            filt_thresholds, cuts = self.__filt_threshold(
                thresholds=thresholds, step=0.01)
            new_thresholds = []
            new_tpr = []
            new_fpr = []
            for threshold in filt_thresholds:
                index = thresholds.index(threshold)
                new_tpr.append(tpr[index])
                new_fpr.append(fpr[index])
                new_thresholds.append(threshold)
            fpr = new_fpr
            tpr = new_tpr
            thresholds = new_thresholds
            return fpr, tpr, thresholds, cuts
        else:
            logging.warning(
                "roc_curve supports binary classification only; returning None")
            fpr, tpr, thresholds, cuts = None, None, None, None
            return fpr, tpr, thresholds, cuts
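    # Illustrative usage (a sketch): for binary eval types, the method returns
    # four parallel lists suitable for plotting; otherwise all four are None.
    #   fpr, tpr, thresholds, cuts = mi.roc([0, 1, 1, 0], [0.2, 0.8, 0.7, 0.4])
    #   # each entry in `cuts` marks the quantile position of its threshold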
    def ks(self, labels, pred_scores):
        """
        Compute the Kolmogorov-Smirnov statistic.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        max_ks_interval: float. The maximum of tpr - fpr over all thresholds,
            followed by the fpr, tpr, threshold and cut lists.
        """
        if self.eval_type == consts.ONE_VS_REST:
            try:
                rs = classification_metric.KS().compute(labels, pred_scores)
            except BaseException:
                rs = [0, [0], [0], [0], [0]]  # in case all true labels are 0 or 1
                logging.warning("all true labels are 0/1 when running ovr KS")
            return rs
        else:
            return classification_metric.KS().compute(labels, pred_scores)
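    # Note: the KS statistic is max(tpr - fpr) over all score thresholds.
    # E.g. with labels [0, 1] and scores [0.1, 0.9], tpr - fpr reaches 1.0 at
    # any threshold between the two scores, so KS = 1.0 (perfect separation).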
    def lift(self, labels, pred_scores):
        """
        Compute the lift of binary classification.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            The lift
        """
        if self.eval_type == consts.BINARY:
            return classification_metric.Lift().compute(labels, pred_scores)
        else:
            logging.warning(
                "lift supports binary classification only; returning None")
            return None

    def gain(self, labels, pred_scores):
        """
        Compute the gain of binary classification.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        float
            The gain
        """
        if self.eval_type == consts.BINARY:
            return classification_metric.Gain().compute(labels, pred_scores)
        else:
            logging.warning(
                "gain supports binary classification only; returning None")
            return None

    def precision(self, labels, pred_scores):
        """
        Compute the precision.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        dict
            The key is a threshold and the value is another dict, whose keys are
            the labels in parameter labels and whose values are those labels' precision.
        """
        if self.eval_type == consts.BINARY:
            precision_operator = classification_metric.BiClassPrecision()
            metric_scores, score_threshold, cuts = precision_operator.compute(
                labels, pred_scores)
            return metric_scores, cuts, score_threshold
        elif self.eval_type == consts.MULTY:
            precision_operator = classification_metric.MultiClassPrecision()
            return precision_operator.compute(labels, pred_scores)
        else:
            logging.warning(
                "error: cannot find classification type: {}".format(
                    self.eval_type))
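    # Illustrative usage (a sketch): the binary branch returns three parallel
    # lists (scores, cuts, thresholds); the multiclass branch returns a single
    # per-label result from MultiClassPrecision.
    #   scores, cuts, thresholds = mi.precision([0, 1, 1, 0], [0.2, 0.8, 0.7, 0.4])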
    def recall(self, labels, pred_scores):
        """
        Compute the recall.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.

        Returns
        ----------
        dict
            The key is a threshold and the value is another dict, whose keys are
            the labels in parameter labels and whose values are those labels' recall.
        """
        if self.eval_type == consts.BINARY:
            recall_operator = classification_metric.BiClassRecall()
            recall_res, thresholds, cuts = recall_operator.compute(
                labels, pred_scores)
            return recall_res, cuts, thresholds
        elif self.eval_type == consts.MULTY:
            recall_operator = classification_metric.MultiClassRecall()
            return recall_operator.compute(labels, pred_scores)
        else:
            logging.warning(
                "error: cannot find classification type: {}".format(
                    self.eval_type))

    def accuracy(self, labels, pred_scores, normalize=True):
        """
        Compute the accuracy.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_scores: value list. The predicted scores of the model, one per entry in labels.
        normalize: bool. If True, return the fraction of correctly classified samples;
            otherwise return the number of correctly classified samples.

        Returns
        ----------
        dict
            The key is a threshold and the value is the accuracy at that threshold.
        """
        if self.eval_type == consts.BINARY:
            acc_operator = classification_metric.BiClassAccuracy()
            acc_res, thresholds, cuts = acc_operator.compute(
                labels, pred_scores, normalize)
            return acc_res, cuts, thresholds
        elif self.eval_type == consts.MULTY:
            acc_operator = classification_metric.MultiClassAccuracy()
            return acc_operator.compute(labels, pred_scores, normalize)
        else:
            logging.warning(
                "error: cannot find classification type: {}".format(
                    self.eval_type))
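    # Illustrative usage (a sketch): with normalize=True the accuracy values
    # are fractions in [0, 1]; with normalize=False they are raw counts.
    #   acc, cuts, thresholds = mi.accuracy([0, 1, 1, 0], [0.2, 0.8, 0.7, 0.4])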
    def f1_score(self, labels, pred_scores):
        """
        Compute the f1_score for a binary classification result.
        """
        if self.eval_type == consts.BINARY:
            f1_scores, score_threshold, cuts = classification_metric.FScore().compute(
                labels, pred_scores)
            return list(f1_scores), list(cuts), list(score_threshold)
        else:
            logging.warning(
                'error: f-score metric is for binary classification only')

    def confusion_mat(self, labels, pred_scores):
        """
        Compute the confusion matrix.
        """
        if self.eval_type == consts.BINARY:
            sorted_labels, sorted_scores = classification_metric.sort_score_and_label(
                labels, pred_scores)
            _, cuts = classification_metric.ThresholdCutter.cut_by_step(
                sorted_scores, steps=0.01)
            fixed_interval_threshold = classification_metric.ThresholdCutter.fixed_interval_threshold()
            confusion_mat = classification_metric.ConfusionMatrix.compute(
                sorted_labels, sorted_scores, fixed_interval_threshold, ret=[
                    'tp', 'fp', 'fn', 'tn'])
            confusion_mat['tp'] = self.__to_int_list(confusion_mat['tp'])
            confusion_mat['fp'] = self.__to_int_list(confusion_mat['fp'])
            confusion_mat['fn'] = self.__to_int_list(confusion_mat['fn'])
            confusion_mat['tn'] = self.__to_int_list(confusion_mat['tn'])
            return confusion_mat, cuts, fixed_interval_threshold
        else:
            logging.warning(
                'error: confusion matrix is for binary classification only')
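    # Illustrative usage (a sketch): the returned dict holds one tp/fp/fn/tn
    # count list per fixed-interval threshold, so confusion_mat['tp'][i] is
    # the true-positive count at fixed_interval_threshold[i].
    #   cm, cuts, thresholds = mi.confusion_mat([0, 1, 1, 0], [0.2, 0.8, 0.7, 0.4])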
    def psi(
            self,
            train_scores,
            validate_scores,
            train_labels,
            validate_labels,
            debug=False):
        """
        Compute the PSI (population stability index).

        Parameters
        ----------
        train_scores: The predicted scores on the train data
        validate_scores: The predicted scores on the validate data
        train_labels: labels of the train set
        validate_labels: labels of the validate set
        debug: bool. If True, print additional info
        """
        if self.eval_type == consts.BINARY:
            psi_computer = classification_metric.PSI()
            psi_scores, total_psi, expected_interval, expected_percentage, \
                actual_interval, actual_percentage, train_pos_perc, \
                validate_pos_perc, intervals = psi_computer.compute(
                    train_scores, validate_scores, debug=debug, str_intervals=True,
                    round_num=6, train_labels=train_labels,
                    validate_labels=validate_labels)
            len_list = np.array([len(psi_scores),
                                 len(expected_interval),
                                 len(expected_percentage),
                                 len(actual_interval),
                                 len(actual_percentage),
                                 len(intervals)])
            assert (len_list == len(psi_scores)).all()
            return list(psi_scores), total_psi, self.__to_int_list(expected_interval), \
                list(expected_percentage), self.__to_int_list(actual_interval), \
                list(actual_percentage), list(train_pos_perc), \
                list(validate_pos_perc), intervals
        else:
            logging.warning(
                'error: psi metric is for binary classification only')
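    # Note: per-bin PSI is (actual_pct - expected_pct) * ln(actual_pct / expected_pct),
    # and total_psi is the sum over bins; as a common rule of thumb, PSI < 0.1
    # is read as a stable score distribution between train and validate sets.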
    def quantile_pr(self, labels, pred_scores):
        if self.eval_type == consts.BINARY:
            p = classification_metric.BiClassPrecision(
                cut_method='quantile', remove_duplicate=False)
            r = classification_metric.BiClassRecall(
                cut_method='quantile', remove_duplicate=False)
            p_scores, score_threshold, cuts = p.compute(labels, pred_scores)
            r_scores, score_threshold, cuts = r.compute(labels, pred_scores)
            p_scores = list(map(list, np.flip(p_scores, axis=0)))
            r_scores = list(map(list, np.flip(r_scores, axis=0)))
            score_threshold = list(np.flip(score_threshold))
            return p_scores, r_scores, score_threshold
        else:
            logging.warning(
                'error: pr quantile is for binary classification only')
    @staticmethod
    def jaccard_similarity_score(labels, pred_labels):
        """
        Compute the Jaccard similarity score.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_labels: value list. The predicted labels of the model, one per entry in labels.

        Returns
        ----------
        float
            A non-negative floating point value
        """
        return clustering_metric.JaccardSimilarityScore().compute(labels, pred_labels)

    @staticmethod
    def fowlkes_mallows_score(labels, pred_labels):
        """
        Compute the Fowlkes-Mallows score.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_labels: value list. The predicted labels of the model, one per entry in labels.

        Returns
        ----------
        float
            A non-negative floating point value
        """
        return clustering_metric.FowlkesMallowsScore().compute(labels, pred_labels)

    @staticmethod
    def adjusted_rand_score(labels, pred_labels):
        """
        Compute the adjusted Rand score.

        Parameters
        ----------
        labels: value list. The labels of the data set.
        pred_labels: value list. The predicted labels of the model, one per entry in labels.

        Returns
        ----------
        float
            A floating point value; negative when the clustering is worse than random
        """
        return clustering_metric.AdjustedRandScore().compute(labels, pred_labels)

    @staticmethod
    def davies_bouldin_index(cluster_avg_intra_dist, cluster_inter_dist):
        """
        Compute the Davies-Bouldin index from pre-aggregated distances.

        Parameters
        ----------
        cluster_avg_intra_dist: average intra-cluster distances, processed from evaluation
        cluster_inter_dist: inter-cluster distances, processed from evaluation
        """
        return clustering_metric.DaviesBouldinIndex().compute(
            cluster_avg_intra_dist, cluster_inter_dist)

    @staticmethod
    def contingency_matrix(labels, pred_labels):
        """
        Compute the contingency matrix of true versus predicted cluster labels.
        """
        return clustering_metric.ContengincyMatrix().compute(labels, pred_labels)

    @staticmethod
    def distance_measure(
            cluster_avg_intra_dist,
            cluster_inter_dist,
            max_radius):
        """
        Compute distance measures of the clustering from pre-aggregated distances.
        """
        return clustering_metric.DistanceMeasure().compute(
            cluster_avg_intra_dist, cluster_inter_dist, max_radius)
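

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): exercises
    # the binary-classification branch only, assuming consts.BINARY is defined
    # as in federatedml.util.consts.
    labels = [0, 0, 1, 1, 0, 1]
    scores = [0.1, 0.35, 0.8, 0.65, 0.4, 0.9]
    mi = MetricInterface(pos_label=1, eval_type=consts.BINARY)
    print("auc:", mi.auc(labels, scores))   # perfectly ranked scores -> 1.0
    print("ks:", mi.ks(labels, scores)[0])  # first element is max(tpr - fpr)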