import unittest

import numpy as np

from federatedml.callbacks.validation_strategy import ValidationStrategy
from federatedml.param.evaluation_param import EvaluateParam
from federatedml.util import consts


class TestValidationStrategy(unittest.TestCase):
    """Tests for early stopping, best-iteration tracking and mode checking in ValidationStrategy."""

    def setUp(self) -> None:
        self.role = 'guest'
        self.mode = 'hetero'
        self.early_stopping_round = 1
        self.use_first_metric_only = False

    @staticmethod
    def generate_fake_eval_metrics(total_rounds, decrease_round, metrics=('ks', 'auc'), start_val=0.8):
        """Generate metrics that improve for (total_rounds - decrease_round) rounds, then degrade."""
        assert total_rounds >= decrease_round
        eval_result_list = []
        start_decrease_round = total_rounds - decrease_round
        for i in range(total_rounds):
            if i < start_decrease_round:
                start_val += 0.01
            else:
                start_val -= 0.01
            eval_dict = {metric: start_val for metric in metrics}
            eval_result_list.append(eval_dict)
        return eval_result_list
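
    # Example (values approximate, ignoring float rounding): generate_fake_eval_metrics(4, 2)
    # returns [{'ks': 0.81, 'auc': 0.81}, {'ks': 0.82, 'auc': 0.82},
    #          {'ks': 0.81, 'auc': 0.81}, {'ks': 0.80, 'auc': 0.80}],
    # i.e. the metrics rise for the first two rounds and then fall.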

    def test_early_stopping(self):
        test_rounds = list(range(10, 100))
        decrease_rounds = [np.random.randint(i) for i in test_rounds]
        for test_round, decrease_round in zip(test_rounds, decrease_rounds):
            eval_dicts = self.generate_fake_eval_metrics(test_round, decrease_round)
            self.early_stopping_round = decrease_round - 1
            if self.early_stopping_round <= 0:
                continue
            validation_strategy = ValidationStrategy(
                self.role,
                self.mode,
                early_stopping_rounds=self.early_stopping_round,
                use_first_metric_only=self.use_first_metric_only)
            for idx, eval_res in enumerate(eval_dicts):
                validation_strategy.performance_recorder.update(eval_res)
                check_rs = validation_strategy.check_early_stopping()
                if check_rs:
                    # metrics peak at index (test_round - decrease_round - 1), so early stopping
                    # should trigger exactly early_stopping_round iterations later
                    self.assertEqual(test_round - decrease_round + self.early_stopping_round - 1, idx)
                    print('test checking passed')
                    break

    def test_use_first_metric_only(self):

        def evaluate(param, early_stopping_rounds, use_first_metric_only):
            # pick the first single-value metric of the given eval type from the metric list
            eval_type = param.eval_type
            metric_list = param.metrics
            first_metric = None

            if early_stopping_rounds and use_first_metric_only and len(metric_list) != 0:
                single_metric_list = None
                if eval_type == consts.BINARY:
                    single_metric_list = consts.BINARY_SINGLE_VALUE_METRIC
                elif eval_type == consts.REGRESSION:
                    single_metric_list = consts.REGRESSION_SINGLE_VALUE_METRICS
                elif eval_type == consts.MULTY:
                    single_metric_list = consts.MULTI_SINGLE_VALUE_METRIC

                for metric in metric_list:
                    if metric in single_metric_list:
                        first_metric = metric
                        break

            return first_metric

        param_0 = EvaluateParam(metrics=['roc', 'lift', 'ks', 'auc', 'gain'], eval_type='binary')
        param_1 = EvaluateParam(metrics=['acc', 'precision', 'auc'], eval_type='binary')
        param_2 = EvaluateParam(metrics=['acc', 'precision', 'gain', 'recall', 'lift'], eval_type='binary')
        param_3 = EvaluateParam(metrics=['acc', 'precision', 'gain', 'auc', 'recall'], eval_type='multi')

        print(evaluate(param_0, 10, True))
        print(evaluate(param_1, 10, True))
        print(evaluate(param_2, 10, True))
        print(evaluate(param_3, 10, True))

    def test_best_iter(self):
        test_rounds = list(range(10, 100))
        decrease_rounds = [np.random.randint(i) for i in test_rounds]
        for test_round, decrease_round in zip(test_rounds, decrease_rounds):
            eval_dicts = self.generate_fake_eval_metrics(test_round, decrease_round)
            self.early_stopping_round = decrease_round - 1
            if self.early_stopping_round <= 0:
                continue
            validation_strategy = ValidationStrategy(
                self.role,
                self.mode,
                early_stopping_rounds=self.early_stopping_round,
                use_first_metric_only=self.use_first_metric_only)
            for idx, eval_res in enumerate(eval_dicts):
                validation_strategy.performance_recorder.update(eval_res)
                check_rs = validation_strategy.check_early_stopping()
                if check_rs:
                    # the best recorded performance should be the last round before metrics decrease
                    best_perform = validation_strategy.performance_recorder.cur_best_performance
                    self.assertDictEqual(best_perform, eval_dicts[test_round - decrease_round - 1])
                    print('best iter checking passed')
                    break

    def test_homo_checking(self):
        # early stopping is not supported in homo mode, so construction is expected to fail
        with self.assertRaises(Exception) as context:
            ValidationStrategy(self.role, mode='homo', early_stopping_rounds=1)
        print('error detected {}, homo checking passed'.format(context.exception))


if __name__ == '__main__':
    tvs = TestValidationStrategy()
    tvs.setUp()
    tvs.test_use_first_metric_only()
    # tvs.test_early_stopping()
    # tvs.test_best_iter()
    # tvs.test_homo_checking()  # expect checking error !!!
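
    # Note: invoking unittest here (e.g. unittest.main()) would run every test_* method;
    # the explicit calls above only exercise the selected test.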