12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- from federatedml.model_base import ModelBase
- from federatedml.param.homo_nn_param import HomoNNParam
- from federatedml.nn.homo.trainer.trainer_base import get_trainer_class
- from federatedml.model_base import MetricMeta
- from federatedml.util import LOGGER
- from federatedml.nn.homo.client import NNModelExporter
- from federatedml.callbacks.model_checkpoint import ModelCheckpoint
- from federatedml.nn.backend.utils.common import get_homo_param_meta, recover_model_bytes
class HomoNNServer(ModelBase):
    """Server-side (arbiter) component of the homo NN module.

    The component resolves a trainer class from its configured name,
    instantiates it with the configured trainer parameters and delegates the
    aggregation work to the trainer's ``server_aggregate_procedure``. The
    model cached by the trainer afterwards is kept on ``self.model`` so that
    ``export_model`` can return it.
    """

    def __init__(self):
        super().__init__()
        self.model_param = HomoNNParam()

        # Trainer name and keyword arguments; both are filled by _init_model.
        self.trainer = None
        self.trainer_param = None

        # arbiter side models
        self.model = None
        # NOTE(review): nothing in this class ever sets this flag to True —
        # confirm whether load_model is expected to update it.
        self.model_loaded = False

        # arbiter saved extra status
        self.exporter = NNModelExporter()
        self.extra_data = {}

        # warm start: epoch index restored by load_model, if any
        self.warm_start_iter = None

    def export_model(self):
        """Return the model held by this component.

        Returns:
            ``self.model`` when it is set; otherwise whatever
            ``self.exporter.export_model_dict()`` produces when called
            without arguments.
        """
        if self.model is None:
            LOGGER.debug('export an empty model')
            # No model has been cached: fall back to the exporter's
            # argument-less export result.
            return self.exporter.export_model_dict()
        return self.model

    def load_model(self, model_dict):
        """Restore the saved model, extra data and warm-start epoch index.

        Args:
            model_dict: Mapping with a ``"model"`` entry that is itself a
                mapping; only its first value is used. ``None`` is accepted
                and leaves the component untouched.
        """
        if model_dict is None:
            return

        # Only the first entry stored under "model" is taken.
        self.model = list(model_dict["model"].values())[0]
        param, _ = get_homo_param_meta(self.model)

        # load extra data
        self.extra_data = recover_model_bytes(param.extra_data_bytes)
        self.warm_start_iter = param.epoch_idx

    def _init_model(self, param: HomoNNParam):
        """Read the trainer name and trainer keyword arguments from ``param``.

        Args:
            param: Parameter object whose ``trainer.to_dict()`` yields a
                mapping with the keys ``'trainer_name'`` and ``'param'``.
        """
        trainer_conf = param.trainer.to_dict()
        self.trainer = trainer_conf['trainer_name']
        self.trainer_param = trainer_conf['param']
        LOGGER.debug('trainer and trainer param {} {}'.format(
            self.trainer, self.trainer_param))

    def fit(self, data_instance=None, validate_data=None):
        """Run the trainer's server-side aggregation procedure.

        Neither argument is read by this method.

        Args:
            data_instance: Unused.
            validate_data: Unused.
        """
        # fate loss callback setting
        loss_meta = MetricMeta(
            name="train",
            metric_type="LOSS",
            extra_metas={"unit_name": "aggregate_round"},
        )
        self.callback_meta("loss", "train", loss_meta)

        # display warmstart iter
        if self.component_properties.is_warm_start:
            self.callback_warm_start_init_iter(self.warm_start_iter)

        # resolve the trainer class from its configured name
        trainer_class = get_trainer_class(self.trainer)
        LOGGER.info('trainer class is {}'.format(trainer_class))

        # instantiate the trainer with the configured keyword arguments
        trainer_inst = trainer_class(**self.trainer_param)
        # set tracker for fateboard callback
        trainer_inst.set_tracker(self.tracker)
        # set exporter
        trainer_inst.set_model_exporter(self.exporter)
        # set checkpoint
        trainer_inst.set_checkpoint(ModelCheckpoint(self, save_freq=1))

        # run trainer server procedure
        trainer_inst.server_aggregate_procedure(self.extra_data)

        # aggregation process is done, get exported model if any
        self.model = trainer_inst.get_cached_model()
        self.set_summary(trainer_inst.get_summary())

    def predict(self, data_inst):
        """Return ``None``; this component performs no prediction.

        Args:
            data_inst: Unused.
        """
        return None
|