server.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from federatedml.model_base import ModelBase
  2. from federatedml.param.homo_nn_param import HomoNNParam
  3. from federatedml.nn.homo.trainer.trainer_base import get_trainer_class
  4. from federatedml.model_base import MetricMeta
  5. from federatedml.util import LOGGER
  6. from federatedml.nn.homo.client import NNModelExporter
  7. from federatedml.callbacks.model_checkpoint import ModelCheckpoint
  8. from federatedml.nn.backend.utils.common import get_homo_param_meta, recover_model_bytes
  9. class HomoNNServer(ModelBase):
  10. def __init__(self):
  11. super(HomoNNServer, self).__init__()
  12. self.model_param = HomoNNParam()
  13. self.trainer = None
  14. self.trainer_param = None
  15. # arbiter side models
  16. self.model = None
  17. self.model_loaded = False
  18. # arbiter saved extra status
  19. self.exporter = NNModelExporter()
  20. self.extra_data = {}
  21. # warm start
  22. self.warm_start_iter = None
  23. def export_model(self):
  24. if self.model is None:
  25. LOGGER.debug('export an empty model')
  26. return self.exporter.export_model_dict() # return an exporter
  27. return self.model
  28. def load_model(self, model_dict):
  29. if model_dict is not None:
  30. model_dict = list(model_dict["model"].values())[0]
  31. self.model = model_dict
  32. param, meta = get_homo_param_meta(self.model)
  33. # load extra data
  34. self.extra_data = recover_model_bytes(param.extra_data_bytes)
  35. self.warm_start_iter = param.epoch_idx
  36. def _init_model(self, param: HomoNNParam()):
  37. train_param = param.trainer.to_dict()
  38. self.trainer = train_param['trainer_name']
  39. self.trainer_param = train_param['param']
  40. LOGGER.debug('trainer and trainer param {} {}'.format(
  41. self.trainer, self.trainer_param))
  42. def fit(self, data_instance=None, validate_data=None):
  43. # fate loss callback setting
  44. self.callback_meta(
  45. "loss", "train", MetricMeta(
  46. name="train", metric_type="LOSS", extra_metas={
  47. "unit_name": "aggregate_round"}))
  48. # display warmstart iter
  49. if self.component_properties.is_warm_start:
  50. self.callback_warm_start_init_iter(self.warm_start_iter)
  51. # initialize trainer
  52. trainer_class = get_trainer_class(self.trainer)
  53. LOGGER.info('trainer class is {}'.format(trainer_class))
  54. # init trainer
  55. trainer_inst = trainer_class(**self.trainer_param)
  56. # set tracker for fateboard callback
  57. trainer_inst.set_tracker(self.tracker)
  58. # set exporter
  59. trainer_inst.set_model_exporter(self.exporter)
  60. # set chceckpoint
  61. trainer_inst.set_checkpoint(ModelCheckpoint(self, save_freq=1))
  62. # run trainer server procedure
  63. trainer_inst.server_aggregate_procedure(self.extra_data)
  64. # aggregation process is done, get exported model if any
  65. self.model = trainer_inst.get_cached_model()
  66. self.set_summary(trainer_inst.get_summary())
  67. def predict(self, data_inst):
  68. return None