validation_strategy.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
#
#
################################################################################
import copy

from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.param.evaluation_param import EvaluateParam
from federatedml.evaluation.performance_recorder import PerformanceRecorder
from federatedml.transfer_variable.transfer_class.validation_strategy_transfer_variable import \
    ValidationStrategyVariable
from federatedml.callbacks.callback_base import CallbackBase
from federatedml.feature.instance import Instance


class ValidationStrategy(CallbackBase):
  32. """
  33. This module is used for evaluating the performance of model during training process.
  34. it will be called only in fit process of models.
  35. Attributes
  36. ----------
  37. validation_freqs: None or positive integer or container object in python. Do validation in training process or Not.
  38. if equals None, will not do validation in train process;
  39. if equals positive integer, will validate data every validation_freqs epochs passes;
  40. if container object in python, will validate data if epochs belong to this container.
  41. e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
  42. Default: None
  43. train_data: None or Table,
  44. if train_data not equal to None, and judge need to validate data according to validation_freqs,
  45. training data will be used for evaluating
  46. validate_data: None or Table,
  47. if validate_data not equal to None, and judge need to validate data according to validation_freqs,
  48. validate data will be used for evaluating
  49. """

    def __init__(self, role=None, mode=None, validation_freqs=None, early_stopping_rounds=None,
                 use_first_metric_only=False, arbiter_comm=True):

        self.validation_freqs = validation_freqs
        self.role = role
        self.mode = mode
        self.flowid = ''
        self.train_data = None
        self.validate_data = None

        # early stopping related vars
        self.arbiter_comm = arbiter_comm
        self.sync_status = False
        self.early_stopping_rounds = early_stopping_rounds
        self.use_first_metric_only = use_first_metric_only
        self.first_metric = None
        self._evaluation_summary = {}

        # precompute scores
        self.cached_train_scores = None
        self.cached_validate_scores = None
        self.use_precompute_train_scores = False
        self.use_precompute_validate_scores = False

        if early_stopping_rounds is not None:
            if early_stopping_rounds <= 0:
                raise ValueError('early_stopping_rounds should be larger than 0')
            if self.mode == consts.HOMO:
                raise ValueError('early stopping is not supported for homo algorithms')

            self.sync_status = True
            LOGGER.debug("early stopping round is {}".format(self.early_stopping_rounds))

        self.cur_best_model = None
        self.best_iteration = -1
        self.metric_best_model = {}  # best model of a certain metric
        self.metric_best_iter = {}  # best iter of a certain metric
        self.performance_recorder = PerformanceRecorder()  # recorder to record performances
        self.transfer_inst = ValidationStrategyVariable()

    def set_train_data(self, train_data):
        self.train_data = train_data

    def set_validate_data(self, validate_data):
        self.validate_data = validate_data
        if self.early_stopping_rounds and self.validate_data is None:
            raise ValueError('validate data is needed when early stopping is enabled')

    def set_flowid(self, flowid):
        self.flowid = flowid

    def need_run_validation(self, epoch):
        LOGGER.debug("validation_freqs is {}".format(self.validation_freqs))
        if not self.validation_freqs:
            return False

        if isinstance(self.validation_freqs, int):
            return (epoch + 1) % self.validation_freqs == 0

        return epoch in self.validation_freqs
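
    # Scheduling semantics of need_run_validation, with hypothetical values:
    # an integer frequency fires when (epoch + 1) % validation_freqs == 0, so
    # validation_freqs=3 validates at epochs 2, 5, 8, ...; a container such as
    # [10, 15] is a plain membership test on the zero-based epoch index.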

    @staticmethod
    def generate_flowid(prefix, epoch, keywords="iteration", data_type="train"):
        return "_".join([prefix, keywords, str(epoch), data_type])

    @staticmethod
    def make_data_set_name(need_cv, need_run_ovr, model_flowid, epoch):
        data_iteration_name = "_".join(["iteration", str(epoch)])
        if not need_cv and not need_run_ovr:
            return data_iteration_name

        if need_cv:
            if not need_run_ovr:
                prefix = "_".join(["fold", model_flowid.split(".", -1)[-1]])
            else:
                prefix = "_".join(["fold", model_flowid.split(".", -1)[-2]])
                prefix = ".".join([prefix, model_flowid.split(".", -1)[-1]])
        else:
            prefix = model_flowid.split(".", -1)[-1]

        return ".".join([prefix, data_iteration_name])

    @staticmethod
    def extract_best_model(model):
        best_model = model.export_model()
        return {'model': {'best_model': best_model}} if best_model is not None else None

    def is_best_performance_updated(self, use_first_metric_only=False):
        if len(self.performance_recorder.no_improvement_round.items()) == 0:
            return False
        for metric, no_improve_val in self.performance_recorder.no_improvement_round.items():
            if no_improve_val != 0:
                return False
            if use_first_metric_only:
                break
        return True

    def update_early_stopping_status(self, iteration, model):

        first_metric = True
        if self.role == consts.GUEST:
            LOGGER.info('showing early stopping status, {} shows cur best performances: {}'.format(
                self.role, self.performance_recorder.cur_best_performance))

        LOGGER.info('showing early stopping status, {} shows early stopping no improve rounds: {}'.format(
            self.role, self.performance_recorder.no_improvement_round))

        for metric, no_improve_round in self.performance_recorder.no_improvement_round.items():
            if no_improve_round == 0:
                self.metric_best_iter[metric] = iteration
                self.metric_best_model[metric] = self.extract_best_model(model)
                LOGGER.info('best model of metric {} is now updated to {}'.format(metric, iteration))
                # if early stopping is not triggered, return best model of first metric by default
                if first_metric:
                    LOGGER.info('default best model: metric {}, iter {}'.format(metric, iteration))
                    self.cur_best_model = self.metric_best_model[metric]
                    self.best_iteration = iteration
            first_metric = False

    def check_early_stopping(self):
        """
        Check whether the early stopping condition (early_stopping_rounds) is met.
        Returns bool
        """
        LOGGER.info('checking early stopping')

        no_improvement_dict = self.performance_recorder.no_improvement_round
        for metric in no_improvement_dict:
            if no_improvement_dict[metric] >= self.early_stopping_rounds:
                self.best_iteration = self.metric_best_iter[metric]
                self.cur_best_model = self.metric_best_model[metric]
                LOGGER.info('early stopping triggered, model of iter {} is chosen because metric {} satisfied '
                            'stop condition'.format(self.best_iteration, metric))
                return True

        return False
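
    # Illustration of the trigger above, with a hypothetical recorder state:
    # given early_stopping_rounds=3 and no_improvement_round={'auc': 3},
    # check_early_stopping() returns True and selects the model cached at the
    # last epoch where 'auc' improved (metric_best_iter['auc']).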

    def sync_performance_recorder(self, epoch):
        """
        Synchronize self.performance_recorder across parties.
        """
        if self.mode == consts.HETERO and self.role == consts.GUEST:
            recorder_to_send = copy.deepcopy(self.performance_recorder)
            recorder_to_send.cur_best_performance = None
            if self.arbiter_comm:
                self.transfer_inst.validation_status.remote(recorder_to_send, idx=-1, suffix=(epoch,))
            else:
                self.transfer_inst.validation_status.remote(recorder_to_send, idx=-1, suffix=(epoch,),
                                                            role=consts.HOST)

        elif self.mode == consts.HETERO:
            self.performance_recorder = self.transfer_inst.validation_status.get(idx=-1, suffix=(epoch,))[0]

        else:
            return
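
    # Sync direction implemented above: in hetero mode the guest deep-copies the
    # recorder, blanks cur_best_performance, and sends it to all parties (or to
    # hosts only when arbiter_comm is False); the non-guest parties replace
    # their local recorder with the received copy.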

    def need_stop(self):
        return False if not self.early_stopping_rounds else self.check_early_stopping()

    def has_saved_best_model(self):
        return (self.early_stopping_rounds is not None) and (self.cur_best_model is not None)

    def export_best_model(self):
        if self.has_saved_best_model():
            return self.cur_best_model
        else:
            return None

    def summary(self):
        return self._evaluation_summary

    def update_metric_summary(self, metric_dict):

        iter_name = list(metric_dict.keys())[0]
        metric_dict = metric_dict[iter_name]

        if len(self._evaluation_summary) == 0:
            self._evaluation_summary = {namespace: {} for namespace in metric_dict}

        for namespace in metric_dict:
            for metric_name in metric_dict[namespace]:
                epoch_metric = metric_dict[namespace][metric_name]
                if metric_name not in self._evaluation_summary[namespace]:
                    self._evaluation_summary[namespace][metric_name] = []
                self._evaluation_summary[namespace][metric_name].append(epoch_metric)

    def evaluate(self, predicts, model, epoch):

        evaluate_param: EvaluateParam = model.get_metrics_param()
        evaluate_param.check_single_value_default_metric()

        from federatedml.evaluation.evaluation import Evaluation
        eval_obj = Evaluation()

        eval_type = evaluate_param.eval_type
        metric_list = evaluate_param.metrics

        if self.early_stopping_rounds and self.use_first_metric_only and len(metric_list) != 0:

            # empty default keeps the membership test below safe for unexpected eval types
            single_metric_list = []
            if eval_type == consts.BINARY:
                single_metric_list = consts.BINARY_SINGLE_VALUE_METRIC
            elif eval_type == consts.REGRESSION:
                single_metric_list = consts.REGRESSION_SINGLE_VALUE_METRICS
            elif eval_type == consts.MULTY:
                single_metric_list = consts.MULTI_SINGLE_VALUE_METRIC

            for metric in metric_list:
                if metric in single_metric_list:
                    self.first_metric = metric
                    LOGGER.debug('use {} as first metric'.format(self.first_metric))
                    break

        eval_obj._init_model(evaluate_param)
        eval_obj.set_tracker(model.tracker)
        data_set_name = self.make_data_set_name(model.need_cv, model.callback_one_vs_rest, model.flowid, epoch)
        eval_data = {data_set_name: predicts}
        eval_result_dict = eval_obj.fit(eval_data, return_result=True)
        epoch_summary = eval_obj.summary()
        self.update_metric_summary(epoch_summary)
        eval_obj.save_data()
        LOGGER.debug("end of eval")
        return eval_result_dict
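
    # First-metric selection in evaluate(), with hypothetical parameters: for a
    # binary task with metrics=['ks', 'auc'], the first entry found in
    # consts.BINARY_SINGLE_VALUE_METRIC becomes self.first_metric, and early
    # stopping then tracks only that metric when use_first_metric_only is set.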

    @staticmethod
    def _add_data_type_map_func(value, data_type):
        new_pred_rs = Instance(features=value.features + [data_type], inst_id=value.inst_id)
        return new_pred_rs

    @staticmethod
    def add_data_type(predicts, data_type: str):
        """
        Append data_type to the features of every predict Instance.
        """
        predicts = predicts.mapValues(lambda value: ValidationStrategy._add_data_type_map_func(value, data_type))
        return predicts
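
    # Tagging sketch for add_data_type, with a hypothetical prediction row: an
    # Instance whose features are [label, predict_label, score] becomes
    # [label, predict_label, score, 'train'] (or 'validate'), so the Evaluation
    # module can split the merged train/validate table back apart.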

    def handle_precompute_scores(self, precompute_scores, data_type):

        if self.mode == consts.HETERO and self.role == consts.HOST:
            return None
        if self.role == consts.ARBITER:
            return None

        LOGGER.debug('using precompute scores')

        return self.add_data_type(precompute_scores, data_type)

    def get_predict_result(self, model, epoch, data, data_type: str):

        if not data:
            return

        LOGGER.debug("start to evaluate data {}".format(data_type))
        model_flowid = model.flowid
        # model_flowid = ".".join(model.flowid.split(".", -1)[1:])
        flowid = self.generate_flowid(model_flowid, epoch, "iteration", data_type)
        model.set_flowid(flowid)
        predicts = model.predict(data)
        model.set_flowid(model_flowid)

        if self.mode == consts.HOMO and self.role == consts.ARBITER:
            pass
        elif self.mode == consts.HETERO and self.role == consts.HOST:
            pass
        else:
            predicts = self.add_data_type(predicts, data_type)

        return predicts

    def set_precomputed_train_scores(self, train_scores):
        self.use_precompute_train_scores = True
        self.cached_train_scores = train_scores

    def set_precomputed_validate_scores(self, validate_scores):
        self.use_precompute_validate_scores = True
        self.cached_validate_scores = validate_scores

    def validate(self, model, epoch):
        """
        :param model: model instance, which has a predict function
        :param epoch: int, epoch index, used to generate the flow id
        """

        LOGGER.debug(
            "begin to check validate status, need_run_validation is {}".format(
                self.need_run_validation(epoch)))

        if not self.need_run_validation(epoch):
            return

        if self.mode == consts.HOMO and self.role == consts.ARBITER:
            return

        if not self.use_precompute_train_scores:  # call model.predict()
            train_predicts = self.get_predict_result(model, epoch, self.train_data, "train")
        else:  # use precomputed scores
            train_predicts = self.handle_precompute_scores(self.cached_train_scores, 'train')

        if not self.use_precompute_validate_scores:  # call model.predict()
            validate_predicts = self.get_predict_result(model, epoch, self.validate_data, "validate")
        else:  # use precomputed scores
            validate_predicts = self.handle_precompute_scores(self.cached_validate_scores, 'validate')

        if train_predicts is not None or validate_predicts is not None:

            predicts = train_predicts
            if validate_predicts is not None:
                # train predictions may be absent; fall back to validate-only results
                predicts = validate_predicts if predicts is None else predicts.union(validate_predicts)

            # running evaluation
            eval_result_dict = self.evaluate(predicts, model, epoch)
            LOGGER.debug('showing eval_result_dict here')
            LOGGER.debug(eval_result_dict)

            if self.early_stopping_rounds:

                if len(eval_result_dict) == 0:
                    raise ValueError(
                        "eval_result len is 0, no single value metric detected for early stopping checking")

                if self.use_first_metric_only:
                    if self.first_metric:
                        eval_result_dict = {self.first_metric: eval_result_dict[self.first_metric]}
                    else:
                        LOGGER.warning('use first metric only but no single metric found in metric list')

                self.performance_recorder.update(eval_result_dict)

        if self.sync_status:
            self.sync_performance_recorder(epoch)

        if self.early_stopping_rounds and self.mode == consts.HETERO:
            self.update_early_stopping_status(epoch, model)

    def on_train_begin(self, train_data=None, validate_data=None):
        if self.role != consts.ARBITER:
            self.set_train_data(train_data)
            self.set_validate_data(validate_data)

    def on_epoch_end(self, model, epoch):
        LOGGER.debug('running validation')
        self.validate(model, epoch)
        if self.need_stop():
            LOGGER.debug('early stopping triggered')
            model.callback_variables.stop_training = True

    def on_train_end(self, model):
        if self.has_saved_best_model():
            model.load_model(self.cur_best_model)
            model.callback_variables.best_iteration = self.best_iteration

        model.callback_variables.validation_summary = self.summary()
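
# Precomputed-score path, a minimal sketch: a model that already holds predict
# scores for the current epoch (boosting models, for instance, maintain running
# scores) can skip the extra predict() pass inside validate():
#
#   vs.set_precomputed_train_scores(train_scores)        # Table of Instances
#   vs.set_precomputed_validate_scores(validate_scores)
#   vs.validate(model, epoch)  # routes through handle_precompute_scores()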