123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from federatedml.callbacks.validation_strategy import ValidationStrategy
- from federatedml.callbacks.model_checkpoint import ModelCheckpoint
- from federatedml.param.callback_param import CallbackParam
- from federatedml.util import LOGGER
class CallbackList:
    """Hold the training callbacks of one model and forward lifecycle hooks to them.

    Callbacks are created by ``init_callback_list`` and stored in
    ``callback_list``. Every ``on_*`` hook invokes the same-named method on each
    stored callback, in the order the callbacks were registered.
    """

    def __init__(self, role, mode, model):
        """Store the context later used to build and drive the callbacks.

        :param role: forwarded to ``ValidationStrategy`` when it is created.
        :param mode: forwarded to ``ValidationStrategy`` when it is created.
        :param model: handed to ``ModelCheckpoint`` and to the epoch/train-end
            hooks; its ``component_properties.has_arbiter`` is read when the
            validation strategy is built.
        """
        self.role, self.mode, self.model = role, mode, model
        # Registered callbacks, in registration order.
        self.callback_list = []

    def init_callback_list(self, callback_param: CallbackParam):
        """Create the callbacks named in ``callback_param.callbacks`` and register them.

        * "EarlyStopping" and/or "PerformanceEvaluate" -> a single shared
          ``ValidationStrategy`` built from ``validation_freqs``,
          ``early_stopping_rounds`` and ``use_first_metric_only``.
        * "ModelCheckpoint" -> a ``ModelCheckpoint`` configured with ``save_freq``.

        New callbacks are appended to ``self.callback_list``; the list is not
        cleared first, so calling this method again adds further instances.
        """
        LOGGER.debug(f"self_model: {self.model}")

        # Either name enables the same validation callback.
        wants_validation = any(
            name in callback_param.callbacks
            for name in ("EarlyStopping", "PerformanceEvaluate")
        )
        if wants_validation:
            # Read before building the strategy, matching the original evaluation order.
            arbiter_comm = self.model.component_properties.has_arbiter
            strategy = ValidationStrategy(
                self.role,
                self.mode,
                callback_param.validation_freqs,
                callback_param.early_stopping_rounds,
                callback_param.use_first_metric_only,
                arbiter_comm=arbiter_comm,
            )
            self.callback_list.append(strategy)

        if "ModelCheckpoint" in callback_param.callbacks:
            checkpoint = ModelCheckpoint(model=self.model,
                                         save_freq=callback_param.save_freq)
            self.callback_list.append(checkpoint)

    def get_validation_strategy(self):
        """Return the first registered ``ValidationStrategy``, or ``None`` if there is none."""
        return next(
            (cb for cb in self.callback_list if isinstance(cb, ValidationStrategy)),
            None,
        )

    def on_train_begin(self, train_data=None, validate_data=None):
        """Forward ``train_data`` and ``validate_data`` to every callback's ``on_train_begin``."""
        for cb in self.callback_list:
            cb.on_train_begin(train_data, validate_data)

    def on_epoch_begin(self, epoch):
        """Call ``on_epoch_begin(model, epoch)`` on every callback."""
        for cb in self.callback_list:
            cb.on_epoch_begin(self.model, epoch)

    def on_epoch_end(self, epoch):
        """Call ``on_epoch_end(model, epoch)`` on every callback."""
        for cb in self.callback_list:
            cb.on_epoch_end(self.model, epoch)

    def on_train_end(self):
        """Call ``on_train_end(model)`` on every callback."""
        for cb in self.callback_list:
            cb.on_train_end(self.model)
|