performance_recorder.py

from federatedml.util import consts


class PerformanceRecorder(object):
    """
    Records single-value performance metrics during the training process.
    """

    def __init__(self):
        # all supported metrics are single-value metrics
        self.allowed_metric = [consts.AUC,
                               consts.EXPLAINED_VARIANCE,
                               consts.MEAN_ABSOLUTE_ERROR,
                               consts.MEAN_SQUARED_ERROR,
                               consts.MEAN_SQUARED_LOG_ERROR,
                               consts.MEDIAN_ABSOLUTE_ERROR,
                               consts.R2_SCORE,
                               consts.ROOT_MEAN_SQUARED_ERROR,
                               consts.PRECISION,
                               consts.RECALL,
                               consts.ACCURACY,
                               consts.KS
                               ]

        self.larger_is_better = [consts.AUC,
                                 consts.R2_SCORE,
                                 consts.PRECISION,
                                 consts.RECALL,
                                 consts.EXPLAINED_VARIANCE,
                                 consts.ACCURACY,
                                 consts.KS
                                 ]

        self.smaller_is_better = [consts.ROOT_MEAN_SQUARED_ERROR,
                                  consts.MEAN_ABSOLUTE_ERROR,
                                  consts.MEAN_SQUARED_ERROR,
                                  consts.MEAN_SQUARED_LOG_ERROR]

        self.cur_best_performance = {}  # best value seen so far, per metric
        self.no_improvement_round = {}  # consecutive rounds without improvement, per metric

    def has_improved(self, val: float, metric: str, cur_best: dict):
        # a metric seen for the first time always counts as an improvement
        if metric not in cur_best:
            return True

        if metric in self.larger_is_better and val > cur_best[metric]:
            return True

        elif metric in self.smaller_is_better and val < cur_best[metric]:
            return True

        return False

    def update(self, eval_dict: dict):
        """
        Update best values and no-improvement counters from one evaluation round.

        Parameters
        ----------
        eval_dict : dict
            {metric_name: metric_val}, e.g. {'auc': 0.99}
        """
        if len(eval_dict) == 0:
            return

        for metric in eval_dict:
            if metric not in self.allowed_metric:
                continue
            if self.has_improved(eval_dict[metric], metric, self.cur_best_performance):
                self.cur_best_performance[metric] = eval_dict[metric]
                self.no_improvement_round[metric] = 0
            else:
                self.no_improvement_round[metric] += 1
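

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal early-stopping loop built on PerformanceRecorder. It assumes the
# federatedml package is importable and that each evaluation round produces a
# dict such as {consts.AUC: 0.91}; the metric values and the patience
# threshold below are hypothetical, not taken from the source.
if __name__ == "__main__":
    recorder = PerformanceRecorder()
    patience = 3  # hypothetical: stop once AUC has not improved for 3 rounds
    mock_auc_per_round = [0.90, 0.92, 0.92, 0.91, 0.915, 0.919]  # made-up example values

    for round_idx, auc in enumerate(mock_auc_per_round):
        recorder.update({consts.AUC: auc})
        if recorder.no_improvement_round.get(consts.AUC, 0) >= patience:
            print('early stopping triggered at round {}'.format(round_idx))
            break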