early_stop_test.py

import unittest

import numpy as np

from federatedml.callbacks.validation_strategy import ValidationStrategy
from federatedml.param.evaluation_param import EvaluateParam
from federatedml.util import consts


class TestValidationStrategy(unittest.TestCase):

    def setUp(self) -> None:
        self.role = 'guest'
        self.mode = 'hetero'
        self.early_stopping_round = 1
        self.use_first_metric_only = False

    @staticmethod
    def generate_fake_eval_metrics(total_rounds, decrease_round, metrics=('ks', 'auc'), start_val=0.8):
        """Build per-round metric dicts that improve, then degrade for the last `decrease_round` rounds."""
        assert total_rounds >= decrease_round
        eval_result_list = []
        start_decrease_round = total_rounds - decrease_round
        for i in range(total_rounds):
            if i < start_decrease_round:
                start_val += 0.01
            else:
                start_val -= 0.01
            eval_dict = {metric: start_val for metric in metrics}
            eval_result_list.append(eval_dict)
        return eval_result_list
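
    # A quick sanity check of the helper above (values shown up to float rounding):
    #   generate_fake_eval_metrics(4, 2) -> [{'ks': 0.81, 'auc': 0.81}, {'ks': 0.82, 'auc': 0.82},
    #                                        {'ks': 0.81, 'auc': 0.81}, {'ks': 0.80, 'auc': 0.80}]
    # i.e. both metrics rise for total_rounds - decrease_round rounds, then fall.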

    def test_early_stopping(self):
        test_rounds = list(range(10, 100))
        decrease_rounds = [np.random.randint(i) for i in test_rounds]

        for test_round, decrease_round in zip(test_rounds, decrease_rounds):
            eval_dicts = self.generate_fake_eval_metrics(test_round, decrease_round)
            self.early_stopping_round = decrease_round - 1
            if self.early_stopping_round <= 0:
                continue

            validation_strategy = ValidationStrategy(
                self.role,
                self.mode,
                early_stopping_rounds=self.early_stopping_round,
                use_first_metric_only=self.use_first_metric_only)

            for idx, eval_res in enumerate(eval_dicts):
                validation_strategy.performance_recorder.update(eval_res)
                check_rs = validation_strategy.check_early_stopping()
                if check_rs:
                    # The best score sits at index test_round - decrease_round - 1; after
                    # early_stopping_round consecutive non-improving rounds, stopping must
                    # trigger exactly early_stopping_round iterations later.
                    self.assertEqual(test_round - decrease_round + self.early_stopping_round - 1, idx)
                    print('test checking passed')
                    break

    def test_use_first_metric_only(self):

        def evaluate(param, early_stopping_rounds, use_first_metric_only):
            eval_type = param.eval_type
            metric_list = param.metrics
            first_metric = None

            if early_stopping_rounds and use_first_metric_only and len(metric_list) != 0:
                single_metric_list = None
                if eval_type == consts.BINARY:
                    single_metric_list = consts.BINARY_SINGLE_VALUE_METRIC
                elif eval_type == consts.REGRESSION:
                    single_metric_list = consts.REGRESSION_SINGLE_VALUE_METRICS
                elif eval_type == consts.MULTY:
                    single_metric_list = consts.MULTI_SINGLE_VALUE_METRIC

                # Pick the first single-value metric, mirroring the use_first_metric_only
                # behaviour of the validation strategy.
                for metric in metric_list:
                    if metric in single_metric_list:
                        first_metric = metric
                        break

            return first_metric

        param_0 = EvaluateParam(metrics=['roc', 'lift', 'ks', 'auc', 'gain'], eval_type='binary')
        param_1 = EvaluateParam(metrics=['acc', 'precision', 'auc'], eval_type='binary')
        param_2 = EvaluateParam(metrics=['acc', 'precision', 'gain', 'recall', 'lift'], eval_type='binary')
        param_3 = EvaluateParam(metrics=['acc', 'precision', 'gain', 'auc', 'recall'], eval_type='multi')

        print(evaluate(param_0, 10, True))
        print(evaluate(param_1, 10, True))
        print(evaluate(param_2, 10, True))
        print(evaluate(param_3, 10, True))

    def test_best_iter(self):
        test_rounds = list(range(10, 100))
        decrease_rounds = [np.random.randint(i) for i in test_rounds]

        for test_round, decrease_round in zip(test_rounds, decrease_rounds):
            eval_dicts = self.generate_fake_eval_metrics(test_round, decrease_round)
            self.early_stopping_round = decrease_round - 1
            if self.early_stopping_round <= 0:
                continue

            validation_strategy = ValidationStrategy(
                self.role,
                self.mode,
                early_stopping_rounds=self.early_stopping_round,
                use_first_metric_only=self.use_first_metric_only)

            for idx, eval_res in enumerate(eval_dicts):
                validation_strategy.performance_recorder.update(eval_res)
                check_rs = validation_strategy.check_early_stopping()
                if check_rs:
                    # The recorded best performance must be the last round before the
                    # metrics started to decrease.
                    best_perform = validation_strategy.performance_recorder.cur_best_performance
                    self.assertDictEqual(best_perform, eval_dicts[test_round - decrease_round - 1])
                    print('best iter checking passed')
                    break

    def test_homo_checking(self):
        try:
            validation_strategy = ValidationStrategy(self.role, mode='homo',
                                                     early_stopping_rounds=1)
        except Exception as e:
            # Constructing with early stopping in homo mode is expected to raise.
            print('error detected {}, homo checking passed'.format(e))


if __name__ == '__main__':
    tvs = TestValidationStrategy()
    tvs.setUp()
    tvs.test_use_first_metric_only()
    # tvs.test_early_stopping()
    # tvs.test_best_iter()
    # tvs.test_homo_checking()  # expect checking error !!!
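
# Running this file directly only exercises test_use_first_metric_only (the other calls are
# commented out above). As with any unittest module, the full class can also be run with
# `python -m unittest early_stop_test`, assuming the file is importable under that name.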