callback_list.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from federatedml.callbacks.validation_strategy import ValidationStrategy
  16. from federatedml.callbacks.model_checkpoint import ModelCheckpoint
  17. from federatedml.param.callback_param import CallbackParam
  18. from federatedml.util import LOGGER
  19. class CallbackList(object):
  20. def __init__(self, role, mode, model):
  21. self.role = role
  22. self.mode = mode
  23. self.model = model
  24. self.callback_list = []
  25. def init_callback_list(self, callback_param: CallbackParam):
  26. LOGGER.debug(f"self_model: {self.model}")
  27. if "EarlyStopping" in callback_param.callbacks or \
  28. "PerformanceEvaluate" in callback_param.callbacks:
  29. has_arbiter = self.model.component_properties.has_arbiter
  30. validation_strategy = ValidationStrategy(self.role, self.mode,
  31. callback_param.validation_freqs,
  32. callback_param.early_stopping_rounds,
  33. callback_param.use_first_metric_only,
  34. arbiter_comm=has_arbiter)
  35. self.callback_list.append(validation_strategy)
  36. if "ModelCheckpoint" in callback_param.callbacks:
  37. model_checkpoint = ModelCheckpoint(model=self.model,
  38. save_freq=callback_param.save_freq)
  39. self.callback_list.append(model_checkpoint)
  40. def get_validation_strategy(self):
  41. for callback_func in self.callback_list:
  42. if isinstance(callback_func, ValidationStrategy):
  43. return callback_func
  44. return None
  45. def on_train_begin(self, train_data=None, validate_data=None):
  46. for callback_func in self.callback_list:
  47. callback_func.on_train_begin(train_data, validate_data)
  48. def on_epoch_end(self, epoch):
  49. for callback_func in self.callback_list:
  50. callback_func.on_epoch_end(self.model, epoch)
  51. def on_epoch_begin(self, epoch):
  52. for callback_func in self.callback_list:
  53. callback_func.on_epoch_begin(self.model, epoch)
  54. def on_train_end(self):
  55. for callback_func in self.callback_list:
  56. callback_func.on_train_end(self.model)