evaluation_param.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from federatedml.util import consts, LOGGER
from federatedml.param.base_param import BaseParam


class EvaluateParam(BaseParam):
    """
    Define the evaluation method for binary/multi-class classification,
    regression and clustering tasks.

    Parameters
    ----------
    eval_type : {'binary', 'multi', 'regression', 'clustering'}
        specify the task type to evaluate;
        'binary' is supported for HomoLR, HeteroLR and SecureBoost,
        'regression' for SecureBoost
    unfold_multi_result : bool
        unfold a multi-class result into several one-vs-rest binary
        classification results
    pos_label : int or float or str
        specify the positive label; must match the label type in the data.
        effective only when eval_type is 'binary'
    metrics : list of str, default None
        metrics to compute; if None or empty, all metrics allowed for the
        given eval_type are used
    run_clustering_arbiter_metric : bool, default False
        whether to compute clustering metrics on the arbiter side
    need_run : bool, default True
        indicate whether this module should be run
    """

    def __init__(self, eval_type="binary", pos_label=1, need_run=True, metrics=None,
                 run_clustering_arbiter_metric=False, unfold_multi_result=False):
        super().__init__()
        self.eval_type = eval_type
        self.pos_label = pos_label
        self.need_run = need_run
        self.metrics = metrics
        self.unfold_multi_result = unfold_multi_result
        self.run_clustering_arbiter_metric = run_clustering_arbiter_metric

        # Both tables initially map every task type to its full metric list;
        # check_single_value_default_metric() later narrows default_metrics
        # to a single metric per task type while allowed_metrics stays complete.
        self.default_metrics = {
            consts.BINARY: consts.ALL_BINARY_METRICS,
            consts.MULTY: consts.ALL_MULTI_METRICS,
            consts.REGRESSION: consts.ALL_REGRESSION_METRICS,
            consts.CLUSTERING: consts.ALL_CLUSTER_METRICS
        }

        self.allowed_metrics = {
            consts.BINARY: consts.ALL_BINARY_METRICS,
            consts.MULTY: consts.ALL_MULTI_METRICS,
            consts.REGRESSION: consts.ALL_REGRESSION_METRICS,
            consts.CLUSTERING: consts.ALL_CLUSTER_METRICS
        }

    def _use_single_value_default_metrics(self):
        self.default_metrics = {
            consts.BINARY: consts.DEFAULT_BINARY_METRIC,
            consts.MULTY: consts.DEFAULT_MULTI_METRIC,
            consts.REGRESSION: consts.DEFAULT_REGRESSION_METRIC,
            consts.CLUSTERING: consts.DEFAULT_CLUSTER_METRIC
        }

    def _check_valid_metric(self, metrics_list):
        # All known full metric names, plus the alias table
        # (collection of aliases -> full name), come from consts.
        metric_list = consts.ALL_METRIC_NAME
        alias_name: dict = consts.ALIAS

        full_name_list = []
        metrics_list = [str.lower(i) for i in metrics_list]

        for metric in metrics_list:
            # Already a full metric name: keep it, deduplicated.
            if metric in metric_list:
                if metric not in full_name_list:
                    full_name_list.append(metric)
                continue

            # Otherwise try to resolve it as an alias.
            valid_flag = False
            for alias, full_name in alias_name.items():
                if metric in alias:
                    if full_name not in full_name_list:
                        full_name_list.append(full_name)
                    valid_flag = True
                    break

            if not valid_flag:
                raise ValueError('metric {} is not supported'.format(metric))

        allowed_metrics = self.allowed_metrics[self.eval_type]
        for m in full_name_list:
            if m not in allowed_metrics:
                raise ValueError('metric {} is not used for {} task'.format(m, self.eval_type))

        # Precision and recall are computed as a pair, so requesting either
        # one pulls in the other.
        if consts.RECALL in full_name_list and consts.PRECISION not in full_name_list:
            full_name_list.append(consts.PRECISION)

        if consts.RECALL not in full_name_list and consts.PRECISION in full_name_list:
            full_name_list.append(consts.RECALL)

        return full_name_list
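
    # Illustrative walkthrough, assuming (hypothetically) that consts.ALIAS
    # maps collections of alias strings to full metric names, e.g.
    # {('mae',): 'mean_absolute_error'}: for a regression task,
    # _check_valid_metric(['MAE', 'mean_squared_error']) lowercases both
    # names, expands 'mae' to 'mean_absolute_error', keeps the already-full
    # 'mean_squared_error', and returns the two full names.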

    def check(self):
        descr = "evaluate param's "
        self.eval_type = self.check_and_change_lower(self.eval_type,
                                                     [consts.BINARY, consts.MULTY, consts.REGRESSION,
                                                      consts.CLUSTERING],
                                                     descr)

        if type(self.pos_label).__name__ not in ["str", "float", "int"]:
            raise ValueError(
                "evaluate param's pos_label {} not supported, should be str or float or int type".format(
                    self.pos_label))

        if type(self.need_run).__name__ != "bool":
            raise ValueError(
                "evaluate param's need_run {} not supported, should be bool".format(
                    self.need_run))

        # Fall back to the default metrics of the task type when none given.
        if self.metrics is None or len(self.metrics) == 0:
            self.metrics = self.default_metrics[self.eval_type]
            LOGGER.warning('use default metric {} for eval type {}'.format(self.metrics, self.eval_type))

        self.check_boolean(self.unfold_multi_result, 'unfold_multi_result')
        self.metrics = self._check_valid_metric(self.metrics)

        return True

    def check_single_value_default_metric(self):
        self._use_single_value_default_metrics()

        # in the validation strategy, psi, f1-score, confusion-mat and
        # pr-quantile are not supported in the current version
        if self.metrics is None or len(self.metrics) == 0:
            self.metrics = self.default_metrics[self.eval_type]
            LOGGER.warning('use default metric {} for eval type {}'.format(self.metrics, self.eval_type))

        ban_metric = [consts.PSI, consts.F1_SCORE, consts.CONFUSION_MAT, consts.QUANTILE_PR]
        # Rebuild the list instead of removing while iterating, which would
        # skip elements (and could mutate the shared default list in consts).
        self.metrics = [metric for metric in self.metrics if metric not in ban_metric]
        self.check()
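

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original module. It
    # assumes a FATE environment where the federatedml imports above resolve,
    # and that 'auc' is a valid binary metric name in consts.
    param = EvaluateParam(eval_type="BINARY", metrics=["auc"])
    param.check()             # lowercases eval_type, validates metrics
    print(param.eval_type)    # expected: 'binary'
    print(param.metrics)      # expected: ['auc']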