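# FATE HomoNN client-side component: NNModelExporter packs trained torch
# models into the HomoNN protobufs, and HomoNNClient drives local training,
# prediction and model export for horizontal (homo) federated NN learning.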
import json
import torch
import inspect

from fate_arch.computing.non_distributed import LocalData
from fate_arch.computing._util import is_table
from fate_arch.session import computing_session
from federatedml.model_base import ModelBase, MetricMeta
from federatedml.param.homo_nn_param import HomoNNParam
from federatedml.nn.homo.trainer.trainer_base import get_trainer_class, TrainerBase, StdReturnFormat, ExporterBase
from federatedml.nn.backend.torch import serialization as s
from federatedml.nn.backend.torch.base import FateTorchOptimizer
from federatedml.nn.backend.utils.data import load_dataset, get_ret_predict_table, add_match_id
from federatedml.nn.backend.utils.common import (
    global_seed,
    get_homo_model_dict,
    get_homo_param_meta,
    recover_model_bytes,
    get_torch_model_bytes,
)
from federatedml.callbacks.model_checkpoint import ModelCheckpoint
from federatedml.statistic.data_overview import check_with_inst_id
from federatedml.util import LOGGER, consts
from federatedml.protobuf.generated.homo_nn_model_param_pb2 import HomoNNParam as HomoNNParamPB
from federatedml.protobuf.generated.homo_nn_model_meta_pb2 import HomoNNMeta as HomoNNMetaPB
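

# Exporter injected into trainers (see fit()) so they can snapshot models at
# any epoch: torch state dicts go into the param protobuf, while the JSON
# defines of the network/optimizer/loss go into the meta protobuf, which
# init() later reads back through get_homo_param_meta().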
class NNModelExporter(ExporterBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def export_model_dict(self, model=None, optimizer=None, model_define=None, optimizer_define=None,
                          loss_define=None, epoch_idx=-1, converge_status=False, loss_history=None,
                          best_epoch=-1, extra_data=None):

        # use None instead of a mutable {} default so calls do not share state
        if extra_data is None:
            extra_data = {}

        if issubclass(type(model), torch.nn.Module):
            model_statedict = model.state_dict()
        else:
            model_statedict = None

        opt_state_dict = None
        if optimizer is not None:
            assert isinstance(optimizer, torch.optim.Optimizer), \
                'optimizer must be an instance of torch.optim.Optimizer'
            opt_state_dict = optimizer.state_dict()

        model_status = {
            'model': model_statedict,
            'optimizer': opt_state_dict,
        }
        model_saved_bytes = get_torch_model_bytes(model_status)
        extra_data_bytes = get_torch_model_bytes(extra_data)

        param = HomoNNParamPB()
        meta = HomoNNMetaPB()

        # save param
        param.model_bytes = model_saved_bytes
        param.extra_data_bytes = extra_data_bytes
        param.epoch_idx = epoch_idx
        param.converge_status = converge_status
        param.best_epoch = best_epoch
        if loss_history is None:
            loss_history = []
        param.loss_history.extend(loss_history)

        # save meta
        meta.nn_define.append(json.dumps(model_define))
        meta.optimizer_define.append(json.dumps(optimizer_define))
        meta.loss_func_define.append(json.dumps(loss_define))

        return get_homo_model_dict(param, meta)
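

# The dict returned by get_homo_model_dict() pairs the two protobufs; the
# exact key layout is owned by get_homo_model_dict()/get_homo_param_meta(),
# so HomoNNClient always round-trips it through those helpers instead of
# touching the keys directly.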
class HomoNNClient(ModelBase):

    def __init__(self):
        super(HomoNNClient, self).__init__()
        self.model_param = HomoNNParam()
        self.trainer = consts.FEDAVG_TRAINER
        self.trainer_param = {}
        self.dataset_module = None
        self.dataset = None
        self.dataset_param = {}
        self.torch_seed = None
        self.loss = None
        self.optimizer = None
        self.nn_define = None

        # running variables
        self.trainer_inst = None

        # export model
        self.exporter = NNModelExporter()
        self.model_loaded = False
        self.model = None

        # cache dataset
        self.cache_dataset = {}

        # dtable partitions
        self.partitions = 4

        # warm start display iter
        self.warm_start_iter = None

    def _init_model(self, param: HomoNNParam):
        train_param = param.trainer.to_dict()
        dataset_param = param.dataset.to_dict()
        self.trainer = train_param['trainer_name']
        self.dataset = dataset_param['dataset_name']
        self.trainer_param = train_param['param']
        self.dataset_param = dataset_param['param']
        self.torch_seed = param.torch_seed
        self.nn_define = param.nn_define
        self.loss = param.loss
        self.optimizer = param.optimizer
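
    # init() assembles everything fit()/predict() need: it seeds torch,
    # resolves the trainer class, restores nn_define/optimizer/loss and the
    # saved state dicts from a loaded protobuf when warm-starting, and checks
    # that trainer.train() accepts the expected argument list.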
    def init(self):

        # set random seed
        global_seed(self.torch_seed)

        # load trainer class
        if self.trainer is None:
            raise ValueError(
                'Trainer is not specified, please specify your trainer')

        trainer_class = get_trainer_class(self.trainer)
        LOGGER.info('trainer class is {}'.format(trainer_class))

        # recover model from model config / or recover from saved model param
        loaded_model_dict = None

        # if has model protobuf, load model config from protobuf
        load_opt_state_dict = False
        if self.model_loaded:

            param, meta = get_homo_param_meta(self.model)
            # validate the protobufs before touching their fields
            if param is None or meta is None:
                raise ValueError(
                    'model protobuf is None, make sure '
                    'that your trainer calls export_model() function to save models')
            self.warm_start_iter = param.epoch_idx

            if len(meta.nn_define) == 0:
                raise ValueError(
                    'nn_define is empty, model protobuf has no nn-define, make sure '
                    'that your trainer calls export_model() function to save models')

            self.nn_define = json.loads(meta.nn_define[0])
            loss = json.loads(meta.loss_func_define[0])
            optimizer = json.loads(meta.optimizer_define[0])
            loaded_model_dict = recover_model_bytes(param.model_bytes)
            extra_data = recover_model_bytes(param.extra_data_bytes)

            # a user-supplied optimizer/loss define overrides the saved one;
            # otherwise reuse the saved define and restore its state dict
            if self.optimizer is not None and optimizer != self.optimizer:
                LOGGER.info('optimizer updated')
            else:
                self.optimizer = optimizer
                load_opt_state_dict = True

            if self.loss is not None and self.loss != loss:
                LOGGER.info('loss updated')
            else:
                self.loss = loss
        else:
            extra_data = {}

        # check key param
        if self.nn_define is None:
            raise ValueError(
                'Model structure is not defined, nn_define is None, please check your param')

        # get model from nn define
        model = s.recover_sequential_from_dict(self.nn_define)
        if loaded_model_dict:
            model.load_state_dict(loaded_model_dict['model'])
            LOGGER.info('load model state dict from check point')

        LOGGER.info('model structure is {}'.format(model))

        # init optimizer
        if self.optimizer is not None:
            optimizer_: FateTorchOptimizer = s.recover_optimizer_from_dict(
                self.optimizer)
            # pass model parameters to optimizer
            optimizer = optimizer_.to_torch_instance(model.parameters())
            if load_opt_state_dict:
                LOGGER.info('load optimizer state dict')
                optimizer.load_state_dict(loaded_model_dict['optimizer'])
            LOGGER.info('optimizer is {}'.format(optimizer))
        else:
            optimizer = None
            LOGGER.info('optimizer is not specified')

        # init loss
        if self.loss is not None:
            loss_fn = s.recover_loss_fn_from_dict(self.loss)
            LOGGER.info('loss function is {}'.format(loss_fn))
        else:
            loss_fn = None
            LOGGER.info('loss function is not specified')

        # init trainer and check the signature of its train() entry point
        trainer_inst: TrainerBase = trainer_class(**self.trainer_param)
        trainer_train_args = inspect.getfullargspec(trainer_inst.train).args
        args_format = [
            'self',
            'train_set',
            'validate_set',
            'optimizer',
            'loss',
            'extra_data'
        ]
        if len(trainer_train_args) < 6:
            raise ValueError(
                'Train function of trainer should take 6 arguments: {}, but current trainer.train '
                'only takes {} arguments: {}'.format(
                    args_format, len(trainer_train_args), trainer_train_args))

        trainer_inst.set_nn_config(self.nn_define, self.optimizer, self.loss)
        trainer_inst.fed_mode = True

        return trainer_inst, model, optimizer, loss_fn, extra_data
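
    # fit() accepts a DTable or a LocalData path for both inputs; dataset
    # instances are cached in self.cache_dataset, so when train and validate
    # inputs resolve to the same dataset object only the 'train' type is kept.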
    def fit(self, train_input, validate_input=None):

        LOGGER.debug('train input is {}'.format(train_input))

        # train input & validate input are DTables or path str
        if not is_table(train_input):
            if isinstance(train_input, LocalData):
                train_input = train_input.path
                assert train_input is not None, 'input train path is None!'

        if not is_table(validate_input):
            if isinstance(validate_input, LocalData):
                validate_input = validate_input.path
                assert validate_input is not None, 'input validate path is None!'

        # fate loss callback setting
        self.callback_meta(
            "loss",
            "train",
            MetricMeta(
                name="train",
                metric_type="LOSS",
                extra_metas={"unit_name": "epochs"}))

        # set random seed
        global_seed(self.torch_seed)

        self.trainer_inst, model, optimizer, loss_fn, extra_data = self.init()
        self.trainer_inst.set_model(model)
        self.trainer_inst.set_tracker(self.tracker)
        self.trainer_inst.set_model_exporter(self.exporter)

        # load dataset class
        dataset_inst = load_dataset(
            dataset_name=self.dataset,
            data_path_or_dtable=train_input,
            dataset_cache=self.cache_dataset,
            param=self.dataset_param
        )
        # set dataset prefix
        dataset_inst.set_type('train')
        LOGGER.info('train dataset instance is {}'.format(dataset_inst))

        if validate_input:
            val_dataset_inst = load_dataset(
                dataset_name=self.dataset,
                data_path_or_dtable=validate_input,
                dataset_cache=self.cache_dataset,
                param=self.dataset_param
            )
            if id(val_dataset_inst) != id(dataset_inst):
                # tag the validate dataset only when it is a distinct object,
                # so a shared cached instance keeps its 'train' type
                val_dataset_inst.set_type('validate')
            LOGGER.info('validate dataset instance is {}'.format(val_dataset_inst))
        else:
            val_dataset_inst = None

        # display warmstart iter
        if self.component_properties.is_warm_start:
            self.callback_warm_start_init_iter(self.warm_start_iter)

        # set model check point
        self.trainer_inst.set_checkpoint(ModelCheckpoint(self, save_freq=1))

        # training
        self.trainer_inst.train(
            dataset_inst,
            val_dataset_inst,
            optimizer,
            loss_fn,
            extra_data
        )

        # training is done, get exported model
        self.model = self.trainer_inst.get_cached_model()
        self.set_summary(self.trainer_inst.get_summary())
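
    # predict() reuses the fitted trainer when present, otherwise it rebuilds
    # model and trainer from the loaded protobuf via init(); trainers must
    # return a StdReturnFormat, or prediction is skipped.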
    def predict(self, cpn_input):

        with_inst_id = False
        schema = None
        if not is_table(cpn_input):
            if isinstance(cpn_input, LocalData):
                cpn_input = cpn_input.path
                assert cpn_input is not None, 'input path is None!'
        else:
            with_inst_id = check_with_inst_id(cpn_input)
            schema = cpn_input.schema

        LOGGER.info('running predict')
        if self.trainer_inst is None:
            # init model
            self.trainer_inst, model, optimizer, loss_fn, _ = self.init()
            self.trainer_inst.set_model(model)
            self.trainer_inst.set_tracker(self.tracker)

        dataset_inst = load_dataset(
            dataset_name=self.dataset,
            data_path_or_dtable=cpn_input,
            dataset_cache=self.cache_dataset,
            param=self.dataset_param)

        if not dataset_inst.has_dataset_type():
            dataset_inst.set_type('predict')

        trainer_ret = self.trainer_inst.predict(dataset_inst)
        if trainer_ret is None or not isinstance(trainer_ret, StdReturnFormat):
            LOGGER.info(
                'trainer did not return formatted predicted result, skip predict')
            return None

        id_table, pred_table, classes = trainer_ret()

        if with_inst_id:  # set match id
            add_match_id(id_table=id_table, dataset_inst=dataset_inst)

        id_dtable, pred_dtable = get_ret_predict_table(
            id_table, pred_table, classes, self.partitions, computing_session)
        ret_table = self.predict_score_to_output(
            id_dtable, pred_dtable, classes)
        if schema is not None:
            self.set_predict_data_schema(ret_table, schema)

        return ret_table
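
    # export_model()/load_model() round-trip the dict produced by
    # NNModelExporter.export_model_dict(); an empty dict is exported when
    # training cached no model.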
    def export_model(self):
        if self.model is None:
            LOGGER.debug('export an empty model')
            return self.exporter.export_model_dict()  # return an empty model
        return self.model

    def load_model(self, model_dict):
        model_dict = list(model_dict["model"].values())[0]
        self.model = model_dict
        self.model_loaded = True
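
    # The override below fixes the prediction schema to the five standard
    # HomoNN output columns and propagates match_id_name when the input
    # schema carries one.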
    # override function
    @staticmethod
    def set_predict_data_schema(predict_datas, schemas):
        if predict_datas is None:
            return predict_datas

        if isinstance(predict_datas, list):
            predict_data = predict_datas[0]
            schema = schemas[0]
        else:
            predict_data = predict_datas
            schema = schemas

        if predict_data is not None:
            predict_data.schema = {
                "header": [
                    "label",
                    "predict_result",
                    "predict_score",
                    "predict_detail",
                    "type",
                ],
                "sid": "id",
                "content_type": "predict_result"
            }
            if schema.get("match_id_name") is not None:
                predict_data.schema["match_id_name"] = schema.get("match_id_name")

        return predict_data
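

# A minimal usage sketch for the exporter (illustrative only: `my_model`,
# `my_opt` and the *_define dicts are assumed placeholders, not part of this
# module):
#
#   exporter = NNModelExporter()
#   model_dict = exporter.export_model_dict(
#       model=my_model,                  # a torch.nn.Module
#       optimizer=my_opt,                # a torch.optim.Optimizer
#       model_define=nn_define,          # JSON-serializable defines
#       optimizer_define=opt_define,
#       loss_define=loss_define,
#       epoch_idx=10,
#       converge_status=True)
#
# In fed mode the same exporter instance is handed to the trainer through
# trainer_inst.set_model_exporter(self.exporter) inside fit().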