#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from federatedml.param.base_param import BaseParam
from federatedml.util import LOGGER, consts
class EvaluateParam(BaseParam):
    """
    Define the evaluation method of binary/multiple classification and regression.

    Parameters
    ----------
    eval_type : {'binary', 'regression', 'multi'}
        support 'binary' for HomoLR, HeteroLR and Secureboosting,
        support 'regression' for Secureboosting,
        'multi' is not support these version
    pos_label : int or float or str
        specify positive label type, depend on the data's label.
        this parameter effective only for 'binary'
    need_run : bool, default True
        Indicate if this module needed to be run
    metrics : list of str, optional
        metric names (or their aliases) to compute; when None or empty,
        the default metric suite of the chosen eval_type is used
    run_clustering_arbiter_metric : bool, default False
        whether to run the arbiter-side clustering metrics
    unfold_multi_result : bool, default False
        unfold multi result and get several one-vs-rest binary
        classification results
    """

    def __init__(self, eval_type="binary", pos_label=1, need_run=True, metrics=None,
                 run_clustering_arbiter_metric=False, unfold_multi_result=False):
        super().__init__()
        self.eval_type = eval_type
        self.pos_label = pos_label
        self.need_run = need_run
        self.metrics = metrics
        self.unfold_multi_result = unfold_multi_result
        self.run_clustering_arbiter_metric = run_clustering_arbiter_metric

        # Metrics applied when the caller supplies none: the full suite
        # for each task type.
        self.default_metrics = {
            consts.BINARY: consts.ALL_BINARY_METRICS,
            consts.MULTY: consts.ALL_MULTI_METRICS,
            consts.REGRESSION: consts.ALL_REGRESSION_METRICS,
            consts.CLUSTERING: consts.ALL_CLUSTER_METRICS
        }

        # Metrics a user is permitted to request for each task type.
        self.allowed_metrics = {
            consts.BINARY: consts.ALL_BINARY_METRICS,
            consts.MULTY: consts.ALL_MULTI_METRICS,
            consts.REGRESSION: consts.ALL_REGRESSION_METRICS,
            consts.CLUSTERING: consts.ALL_CLUSTER_METRICS
        }

    def _use_single_value_default_metrics(self):
        """Shrink the per-task defaults to their single-value variants
        (used by the validation strategy, where curve-style metrics are
        not wanted)."""
        self.default_metrics = {
            consts.BINARY: consts.DEFAULT_BINARY_METRIC,
            consts.MULTY: consts.DEFAULT_MULTI_METRIC,
            consts.REGRESSION: consts.DEFAULT_REGRESSION_METRIC,
            consts.CLUSTERING: consts.DEFAULT_CLUSTER_METRIC
        }

    def _check_valid_metric(self, metrics_list):
        """
        Normalize ``metrics_list`` to canonical full metric names.

        Lower-cases every entry, resolves aliases via ``consts.ALIAS``,
        de-duplicates while preserving order, and validates each metric
        against the metrics allowed for ``self.eval_type``.  Precision and
        recall are computed as a pair, so requesting one implicitly adds
        the other.

        Returns
        -------
        list of str
            Canonical metric names, in request order.

        Raises
        ------
        ValueError
            If a metric is unknown, or not usable for the current eval_type.
        """
        known_metrics = consts.ALL_METRIC_NAME
        alias_map: dict = consts.ALIAS

        full_name_list = []
        for metric in (str.lower(m) for m in metrics_list):
            if metric in known_metrics:
                if metric not in full_name_list:
                    full_name_list.append(metric)
                continue

            # Not a canonical name: try to resolve it as an alias.
            # The for/else raises only when no alias group matched.
            for aliases, full_name in alias_map.items():
                if metric in aliases:
                    if full_name not in full_name_list:
                        full_name_list.append(full_name)
                    break
            else:
                raise ValueError('metric {} is not supported'.format(metric))

        allowed_metrics = self.allowed_metrics[self.eval_type]
        for m in full_name_list:
            if m not in allowed_metrics:
                raise ValueError('metric {} is not used for {} task'.format(m, self.eval_type))

        # precision and recall are always evaluated together
        if consts.RECALL in full_name_list and consts.PRECISION not in full_name_list:
            full_name_list.append(consts.PRECISION)
        if consts.RECALL not in full_name_list and consts.PRECISION in full_name_list:
            full_name_list.append(consts.RECALL)

        return full_name_list

    def check(self):
        """
        Validate and normalize every parameter.

        Fills ``self.metrics`` with the task-type default when empty and
        replaces it with the canonical metric-name list.

        Returns
        -------
        bool
            True when all parameters pass validation.

        Raises
        ------
        ValueError
            If any parameter has an unsupported type or value.
        """
        descr = "evaluate param's "
        self.eval_type = self.check_and_change_lower(self.eval_type,
                                                     [consts.BINARY, consts.MULTY, consts.REGRESSION,
                                                      consts.CLUSTERING],
                                                     descr)

        # Exact type-name check (not isinstance) so that bool — a subclass
        # of int — is still rejected as a pos_label.
        if type(self.pos_label).__name__ not in ["str", "float", "int"]:
            raise ValueError(
                "evaluate param's pos_label {} not supported, should be str or float or int type".format(
                    self.pos_label))

        if type(self.need_run).__name__ != "bool":
            raise ValueError(
                "evaluate param's need_run {} not supported, should be bool".format(
                    self.need_run))

        if self.metrics is None or len(self.metrics) == 0:
            self.metrics = self.default_metrics[self.eval_type]
            LOGGER.warning('use default metric {} for eval type {}'.format(self.metrics, self.eval_type))

        self.check_boolean(self.unfold_multi_result, 'multi_result_unfold')
        self.metrics = self._check_valid_metric(self.metrics)

        return True

    def check_single_value_default_metric(self):
        """
        Validate parameters for validation-strategy use.

        Switches the defaults to single-value metrics, strips metrics not
        supported during training validation, then runs the full ``check``.
        """
        self._use_single_value_default_metrics()

        if self.metrics is None or len(self.metrics) == 0:
            self.metrics = self.default_metrics[self.eval_type]
            LOGGER.warning('use default metric {} for eval type {}'.format(self.metrics, self.eval_type))

        # psi, f1-score, confusion-mat and pr-quantile are not supported by
        # the validation strategy in the current version.
        ban_metric = {consts.PSI, consts.F1_SCORE, consts.CONFUSION_MAT, consts.QUANTILE_PR}

        # BUG FIX: the original called self.metrics.remove(metric) while
        # iterating over self.metrics, which (a) skips the element right
        # after each removed one — two banned metrics in a row left the
        # second in place — and (b) mutated in place a list that may be the
        # shared consts default-metrics object. Rebuilding the list fixes
        # both.
        self.metrics = [metric for metric in self.metrics if metric not in ban_metric]

        self.check()
|