test_evaluation_module.py

import unittest

import numpy as np

from federatedml.util import consts
from federatedml.evaluation.metrics import classification_metric, clustering_metric, regression_metric
from federatedml.evaluation.metric_interface import MetricInterface


class TestEvaluation(unittest.TestCase):

    def setUp(self):
        # binary-classification fixtures: random scores with labels derived
        # from a 0.5 threshold
        self.bin_score = np.random.random(100)
        self.bin_label = (self.bin_score > 0.5).astype(int)

        # regression fixtures: independent random scores and labels in [0, 10)
        self.reg_score = np.random.random(100) * 10
        self.reg_label = np.random.random(100) * 10

        # multi-class fixtures: 50 random class indices in [0, 4)
        self.multi_score = np.random.randint(0, 4, size=50)
        self.multi_label = np.random.randint(0, 4, size=50)

        # clustering fixtures: 4 predicted clusters vs. 3 true labels
        self.clustering_score = np.random.randint(0, 4, size=50)
        self.clustering_label = np.random.randint(0, 3, size=50)

        # PSI fixtures: a large "train" score set and a smaller "validation" set
        self.psi_train_score = np.random.random(10000)
        self.psi_train_label = (self.psi_train_score > 0.5).astype(int)
        self.psi_val_score = np.random.random(1000)
        self.psi_val_label = (self.psi_val_score > 0.5).astype(int)

    def test_regression(self):
        print('testing regression metric')
        regression_metric.R2Score().compute(self.reg_score, self.reg_label)
        regression_metric.MSE().compute(self.reg_score, self.reg_label)
        regression_metric.RMSE().compute(self.reg_score, self.reg_label)
        regression_metric.ExplainedVariance().compute(self.reg_score, self.reg_label)
        regression_metric.Describe().compute(self.reg_score)
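        # For reference only (a hedged numpy sketch, not FATE's
        # implementation), with y the labels and p the predictions:
        #   mse  = np.mean((y - p) ** 2)
        #   rmse = np.sqrt(mse)
        #   r2   = 1 - np.sum((y - p) ** 2) / np.sum((y - y.mean()) ** 2)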

    def test_binary(self):
        print('testing binary')
        interface = MetricInterface(pos_label=1, eval_type=consts.BINARY)
        interface.auc(self.bin_label, self.bin_score)
        interface.confusion_mat(self.bin_label, self.bin_score)
        interface.ks(self.bin_label, self.bin_score)
        interface.accuracy(self.bin_label, self.bin_score)
        interface.f1_score(self.bin_label, self.bin_score)
        interface.gain(self.bin_label, self.bin_score)
        interface.lift(self.bin_label, self.bin_score)
        interface.quantile_pr(self.bin_label, self.bin_score)
        interface.precision(self.bin_label, self.bin_score)
        interface.recall(self.bin_label, self.bin_score)
        interface.roc(self.bin_label, self.bin_score)
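        # For reference only (standard definitions, not FATE-specific): the
        # KS statistic is the maximum gap |TPR - FPR| over all score
        # thresholds, and AUC is the area under the ROC curve traced by
        # (FPR, TPR) pairs.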

    def test_psi(self):
        interface = MetricInterface(pos_label=1, eval_type=consts.BINARY)
        interface.psi(
            self.psi_train_score,
            self.psi_val_score,
            train_labels=self.psi_train_label,
            validate_labels=self.psi_val_label)
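        # For reference only (a hedged sketch, not FATE's implementation):
        # PSI bins both score sets and sums, over bins,
        #   psi = np.sum((p - q) * np.log(p / q))
        # where p and q are the train / validation bin fractions.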

    def test_multi(self):
        print('testing multi')
        interface = MetricInterface(eval_type=consts.MULTY, pos_label=1)
        interface.precision(self.multi_label, self.multi_score)
        interface.recall(self.multi_label, self.multi_score)
        interface.accuracy(self.multi_label, self.multi_score)
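        # Note (assumption, not verified against the metric output): in the
        # multi-class case, precision and recall are typically reported per
        # class rather than as a single averaged score.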

    def test_clustering(self):
        print('testing clustering')
        interface = MetricInterface(eval_type=consts.CLUSTERING, pos_label=1)
        interface.confusion_mat(self.clustering_label, self.clustering_score)
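        # A clustering confusion matrix cross-tabulates predicted cluster
        # assignments against true labels; unlike the binary case, no score
        # threshold is involved.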

    def test_newly_added(self):
        print('testing newly added')
        # Distribution consumes (id, score) pairs
        binary_data = list(zip(range(len(self.psi_train_score)), self.psi_train_score))
        classification_metric.Distribution().compute(binary_data, binary_data)
        multi_data = list(zip(range(len(self.multi_score)), self.multi_score))
        classification_metric.Distribution().compute(multi_data, multi_data)
        classification_metric.KSTest().compute(self.multi_score, self.multi_score)
        classification_metric.KSTest().compute(
            self.psi_train_score, self.psi_val_score)
        classification_metric.AveragePrecisionScore().compute(
            self.psi_train_score,
            self.psi_val_score,
            self.psi_train_label,
            self.psi_val_label)
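        # For reference only (not FATE's code): KSTest here is a two-sample
        # Kolmogorov-Smirnov test between two score distributions, roughly
        # what scipy.stats.ks_2samp(a, b) computes.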


if __name__ == '__main__':
    unittest.main()
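
# Run directly (python test_evaluation_module.py) or through the unittest
# runner, e.g. python -m unittest test_evaluation_module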