import unittest

import numpy as np

from federatedml.callbacks.validation_strategy import ValidationStrategy
from federatedml.param.evaluation_param import EvaluateParam
from federatedml.util import consts


class TestValidationStrategy(unittest.TestCase):

    def setUp(self) -> None:
        self.role = 'guest'
        self.mode = 'hetero'
        self.early_stopping_round = 1
        self.use_first_metric_only = False

    @staticmethod
    def generate_fake_eval_metrics(total_rounds, decrease_round, metrics=['ks', 'auc'], start_val=0.8):
        # Build per-round metric dicts whose values rise for the first
        # (total_rounds - decrease_round) rounds and then fall, so early stopping
        # should trigger a fixed number of rounds after the peak.
        assert total_rounds >= decrease_round
        eval_result_list = []
        start_decrease_round = total_rounds - decrease_round
        for i in range(total_rounds):
            if i < start_decrease_round:
                start_val += 0.01
            else:
                start_val -= 0.01
            eval_dict = {metric: start_val for metric in metrics}
            eval_result_list.append(eval_dict)

        return eval_result_list

    def test_early_stopping(self):
        # Early stopping should fire exactly early_stopping_round rounds after
        # the metrics stop improving.
        test_rounds = [i for i in range(10, 100)]
        decrease_rounds = [np.random.randint(i) for i in test_rounds]
        for test_round, decrease_round in zip(test_rounds, decrease_rounds):
            eval_dicts = self.generate_fake_eval_metrics(test_round, decrease_round)
            self.early_stopping_round = decrease_round - 1
            if self.early_stopping_round <= 0:
                continue
            validation_strategy = ValidationStrategy(
                self.role, self.mode,
                early_stopping_rounds=self.early_stopping_round,
                use_first_metric_only=self.use_first_metric_only)
            for idx, eval_res in enumerate(eval_dicts):
                validation_strategy.performance_recorder.update(eval_res)
                check_rs = validation_strategy.check_early_stopping()
                if check_rs:
                    self.assertTrue((test_round - decrease_round + self.early_stopping_round - 1) == idx)
                    print('test checking passed')
                    break

    def test_use_first_metric_only(self):

        def evaluate(param, early_stopping_rounds, use_first_metric_only):
            # Mimic the selection of the first single-value metric that is used
            # when use_first_metric_only is enabled.
            eval_type = param.eval_type
            metric_list = param.metrics
            first_metric = None

            if early_stopping_rounds and use_first_metric_only and len(metric_list) != 0:
                single_metric_list = None
                if eval_type == consts.BINARY:
                    single_metric_list = consts.BINARY_SINGLE_VALUE_METRIC
                elif eval_type == consts.REGRESSION:
                    single_metric_list = consts.REGRESSION_SINGLE_VALUE_METRICS
                elif eval_type == consts.MULTY:
                    single_metric_list = consts.MULTI_SINGLE_VALUE_METRIC

                for metric in metric_list:
                    if metric in single_metric_list:
                        first_metric = metric
                        break

            return first_metric

        param_0 = EvaluateParam(metrics=['roc', 'lift', 'ks', 'auc', 'gain'], eval_type='binary')
        param_1 = EvaluateParam(metrics=['acc', 'precision', 'auc'], eval_type='binary')
        param_2 = EvaluateParam(metrics=['acc', 'precision', 'gain', 'recall', 'lift'], eval_type='binary')
        param_3 = EvaluateParam(metrics=['acc', 'precision', 'gain', 'auc', 'recall'], eval_type='multi')

        print(evaluate(param_0, 10, True))
        print(evaluate(param_1, 10, True))
        print(evaluate(param_2, 10, True))
        print(evaluate(param_3, 10, True))

    def test_best_iter(self):
        # The recorded best performance should match the metrics of the peak round.
        test_rounds = [i for i in range(10, 100)]
        decrease_rounds = [np.random.randint(i) for i in test_rounds]
        for test_round, decrease_round in zip(test_rounds, decrease_rounds):
            eval_dicts = self.generate_fake_eval_metrics(test_round, decrease_round)
            self.early_stopping_round = decrease_round - 1
            if self.early_stopping_round <= 0:
                continue
            validation_strategy = ValidationStrategy(
                self.role, self.mode,
                early_stopping_rounds=self.early_stopping_round,
                use_first_metric_only=self.use_first_metric_only)
            for idx, eval_res in enumerate(eval_dicts):
                validation_strategy.performance_recorder.update(eval_res)
                check_rs = validation_strategy.check_early_stopping()
                if check_rs:
                    best_perform = validation_strategy.performance_recorder.cur_best_performance
                    self.assertDictEqual(best_perform, eval_dicts[test_round - decrease_round - 1])
                    print('best iter checking passed')
                    break

    def test_homo_checking(self):
        try:
            validation_strategy = ValidationStrategy(self.role, mode='homo', early_stopping_rounds=1)
        except Exception as e:
            # throwing an error is expected in homo mode
            print('error detected {}, homo checking passed'.format(e))


if __name__ == '__main__':
    tvs = TestValidationStrategy()
    tvs.setUp()
    tvs.test_use_first_metric_only()
    # tvs.test_early_stopping()
    # tvs.test_best_iter()
    # tvs.test_homo_checking()  # expect checking error !!!