# evaluation_param.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging

from pipeline.param import consts
from pipeline.param.base_param import BaseParam

# Module-level logger; the original file referenced LOGGER without defining it.
LOGGER = logging.getLogger(__name__)
  20. class EvaluateParam(BaseParam):
  21. """
  22. Define the evaluation method of binary/multiple classification and regression
  23. Parameters
  24. ----------
  25. eval_type: string, support 'binary' for HomoLR, HeteroLR and Secureboosting. support 'regression' for Secureboosting. 'multi' is not support these version
  26. unfold_multi_result: bool, unfold multi result and get several one-vs-rest binary classification results
  27. pos_label: specify positive label type, can be int, float and str, this depend on the data's label, this parameter effective only for 'binary'
  28. need_run: bool, default True
  29. Indicate if this module needed to be run
  30. """
  31. def __init__(self, eval_type="binary", pos_label=1, need_run=True, metrics=None,
  32. run_clustering_arbiter_metric=False, unfold_multi_result=False):
  33. super().__init__()
  34. self.eval_type = eval_type
  35. self.pos_label = pos_label
  36. self.need_run = need_run
  37. self.metrics = metrics
  38. self.unfold_multi_result = unfold_multi_result
  39. self.run_clustering_arbiter_metric = run_clustering_arbiter_metric
  40. self.default_metrics = {
  41. consts.BINARY: consts.ALL_BINARY_METRICS,
  42. consts.MULTY: consts.ALL_MULTI_METRICS,
  43. consts.REGRESSION: consts.ALL_REGRESSION_METRICS,
  44. consts.CLUSTERING: consts.ALL_CLUSTER_METRICS
  45. }
  46. self.allowed_metrics = {
  47. consts.BINARY: consts.ALL_BINARY_METRICS,
  48. consts.MULTY: consts.ALL_MULTI_METRICS,
  49. consts.REGRESSION: consts.ALL_REGRESSION_METRICS,
  50. consts.CLUSTERING: consts.ALL_CLUSTER_METRICS
  51. }
  52. def _use_single_value_default_metrics(self):
  53. self.default_metrics = {
  54. consts.BINARY: consts.DEFAULT_BINARY_METRIC,
  55. consts.MULTY: consts.DEFAULT_MULTI_METRIC,
  56. consts.REGRESSION: consts.DEFAULT_REGRESSION_METRIC,
  57. consts.CLUSTERING: consts.DEFAULT_CLUSTER_METRIC
  58. }
  59. def _check_valid_metric(self, metrics_list):
  60. metric_list = consts.ALL_METRIC_NAME
  61. alias_name: dict = consts.ALIAS
  62. full_name_list = []
  63. metrics_list = [str.lower(i) for i in metrics_list]
  64. for metric in metrics_list:
  65. if metric in metric_list:
  66. if metric not in full_name_list:
  67. full_name_list.append(metric)
  68. continue
  69. valid_flag = False
  70. for alias, full_name in alias_name.items():
  71. if metric in alias:
  72. if full_name not in full_name_list:
  73. full_name_list.append(full_name)
  74. valid_flag = True
  75. break
  76. if not valid_flag:
  77. raise ValueError('metric {} is not supported'.format(metric))
  78. allowed_metrics = self.allowed_metrics[self.eval_type]
  79. for m in full_name_list:
  80. if m not in allowed_metrics:
  81. raise ValueError('metric {} is not used for {} task'.format(m, self.eval_type))
  82. if consts.RECALL in full_name_list and consts.PRECISION not in full_name_list:
  83. full_name_list.append(consts.PRECISION)
  84. if consts.RECALL not in full_name_list and consts.PRECISION in full_name_list:
  85. full_name_list.append(consts.RECALL)
  86. return full_name_list
  87. def check(self):
  88. descr = "evaluate param's "
  89. self.eval_type = self.check_and_change_lower(self.eval_type,
  90. [consts.BINARY, consts.MULTY, consts.REGRESSION,
  91. consts.CLUSTERING],
  92. descr)
  93. if type(self.pos_label).__name__ not in ["str", "float", "int"]:
  94. raise ValueError(
  95. "evaluate param's pos_label {} not supported, should be str or float or int type".format(
  96. self.pos_label))
  97. if type(self.need_run).__name__ != "bool":
  98. raise ValueError(
  99. "evaluate param's need_run {} not supported, should be bool".format(
  100. self.need_run))
  101. if self.metrics is None or len(self.metrics) == 0:
  102. self.metrics = self.default_metrics[self.eval_type]
  103. LOGGER.warning('use default metric {} for eval type {}'.format(self.metrics, self.eval_type))
  104. self.check_boolean(self.unfold_multi_result, 'multi_result_unfold')
  105. self.metrics = self._check_valid_metric(self.metrics)
  106. LOGGER.info("Finish evaluation parameter check!")
  107. return True
  108. def check_single_value_default_metric(self):
  109. self._use_single_value_default_metrics()
  110. # in validation strategy, psi f1-score and confusion-mat pr-quantile are not supported in cur version
  111. if self.metrics is None or len(self.metrics) == 0:
  112. self.metrics = self.default_metrics[self.eval_type]
  113. LOGGER.warning('use default metric {} for eval type {}'.format(self.metrics, self.eval_type))
  114. ban_metric = [consts.PSI, consts.F1_SCORE, consts.CONFUSION_MAT, consts.QUANTILE_PR]
  115. for metric in self.metrics:
  116. if metric in ban_metric:
  117. self.metrics.remove(metric)
  118. self.check()