123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from pipeline.component.component_base import FateComponent
- from pipeline.component.nn.models.sequantial import Sequential
- from pipeline.component.nn.backend.torch.interactive import InteractiveLayer
- from pipeline.interface import Input
- from pipeline.interface import Output
- from pipeline.utils.tools import extract_explicit_parameter
- from pipeline.component.nn.interface import DatasetParam
class HeteroNN(FateComponent):
    """Pipeline component that configures a Hetero Neural Network ("HeteroNN") job.

    The component holds three layer containers that are filled in by the user
    and serialized by ``compile()``:

    * bottom model      -- ``add_bottom_model``
    * interactive layer -- ``set_interactive_layer`` (only on 'common'/'guest')
    * top model         -- ``add_top_model`` (not allowed on 'host')

    ``_role`` records which view this instance is: the shared 'common'
    component, or a per-party instance tagged by ``get_party_instance``.
    """

    @extract_explicit_parameter
    def __init__(self, task_type="classification", epochs=None, batch_size=-1, early_stop="diff",
                 tol=1e-5, encrypt_param=None, predict_param=None, cv_param=None, interactive_layer_lr=0.1,
                 validation_freqs=None, early_stopping_rounds=None, use_first_metric_only=None,
                 floating_point_precision=23, selector_param=None, seed=100,
                 dataset: DatasetParam = DatasetParam(dataset_name='table'), **kwargs
                 ):
        """
        Parameters used for Hetero Neural Network.

        Note: the body does not read the named arguments directly. It reads
        ``kwargs["explict_parameters"]``, which is filled in by the
        ``extract_explicit_parameter`` decorator (presumably with only the
        keyword arguments the caller passed explicitly -- confirm against
        pipeline.utils.tools). Each of those entries becomes an attribute.

        Parameters
        ----------
        task_type: str, task type of hetero nn model, one of 'classification', 'regression'.
        epochs: int, the maximum iteration for aggregation in training.
        batch_size: int, batch size when updating model.
            -1 means use all data in one batch, i.e. do not use mini-batch strategy.
            Defaults to -1.
        early_stop: str, accepts 'diff' only in this version, default: 'diff'.
            Method used to judge convergence:
            a) diff: use the difference of loss between two iterations to judge convergence.
        tol: float, tolerance value for early stop.
        encrypt_param: dict, see federatedml/param/encrypt_param.
        predict_param: presumably a dict of prediction parameters (see federatedml/param) -- confirm.
        cv_param: presumably a dict of cross-validation parameters (see federatedml/param) -- confirm.
        interactive_layer_lr: float, the learning rate of the interactive layer.
        validation_freqs: None or positive integer or container object in python.
            Whether to run validation during training.
            If None, no validation is done during training;
            if a positive integer, validation runs every validation_freqs epochs;
            if a container object, validation runs when the epoch belongs to the container,
            e.g. validation_freqs = [10, 15] validates at epochs 10 and 15.
            Default: None
        early_stopping_rounds: integer larger than 0.
            Training stops if a metric on the validation data does not improve
            in the last early_stopping_rounds rounds. validation_freqs must be
            set; early stopping is checked at every validation epoch.
        use_first_metric_only: presumably bool, whether only the first metric is
            used for early stopping -- confirm.
        floating_point_precision: None or integer. If not None, use
            floating_point_precision-bit fixed point to speed up calculation,
            e.g. convert x to round(x * 2**floating_point_precision) during
            Paillier operations and divide the result by
            2**floating_point_precision at the end.
        selector_param: presumably a dict configuring sample selection
            (see federatedml/param) -- confirm.
        seed: int, random seed (presumably consumed by the backend -- confirm).
        dataset: DatasetParam, interface describing the dataset. When passed,
            it is validated with ``check()`` and stored as ``to_dict()`` output.
        callback_param: dict, CallbackParam, see federatedml/param/callback_param.
            Not a named argument; it can be supplied as an extra keyword argument.
        """
        explicit_parameters = kwargs["explict_parameters"]
        # These configs are only known after compile(); register them as None up
        # front (presumably so FateComponent tracks them as component params).
        for late_bound_key in ("optimizer", "bottom_nn_define", "top_nn_define",
                               "interactive_layer_define", "loss"):
            explicit_parameters[late_bound_key] = None
        FateComponent.__init__(self, **explicit_parameters)

        # "name" was already handed to FateComponent.__init__; don't re-assign it below.
        explicit_parameters.pop("name", None)
        for param_key, param_value in explicit_parameters.items():
            setattr(self, param_key, param_value)

        self.input = Input(self.name, data_type="multi")
        self.output = Output(self.name, data_type='single')
        self._module_name = "HeteroNN"
        self.optimizer = None
        self.bottom_nn_define = None
        self.top_nn_define = None
        self.interactive_layer_define = None

        # model holders (dropped from the state returned by __getstate__)
        self._bottom_nn_model = Sequential()
        self._interactive_layer = Sequential()
        self._top_nn_model = Sequential()

        # role of this instance: 'common' / 'guest' / 'host'
        self._role = 'common'

        # only present when the caller passed `dataset=` explicitly
        if hasattr(self, 'dataset'):
            assert isinstance(
                self.dataset, DatasetParam), 'dataset must be a DatasetParam class'
            self.dataset.check()
            self.dataset: dict = self.dataset.to_dict()

    def set_role(self, role):
        """Tag this instance with ``role`` ('common', 'guest' or 'host')."""
        self._role = role

    def get_party_instance(self, role="guest", party_id=None) -> 'Component':
        """Return the party instance from the base class, tagged with ``role``.

        The role tag is what lets ``set_interactive_layer`` / ``add_top_model``
        reject configuration that does not belong to that role.
        """
        inst = super().get_party_instance(role, party_id)
        inst.set_role(role)
        return inst

    def add_dataset(self, dataset_param: DatasetParam):
        """Validate ``dataset_param`` and register its dict form as the "dataset" parameter.

        Raises:
            AssertionError: if ``dataset_param`` is not a DatasetParam.
        """
        assert isinstance(
            dataset_param, DatasetParam), 'dataset must be a DatasetParam class'
        dataset_param.check()
        self.dataset: dict = dataset_param.to_dict()
        self._component_parameter_keywords.add("dataset")
        self._component_param["dataset"] = self.dataset

    def add_bottom_model(self, model):
        """Append ``model`` to this instance's bottom model."""
        # The holder is absent on instances restored from __getstate__ output.
        if not hasattr(self, "_bottom_nn_model"):
            self._bottom_nn_model = Sequential()
        self._bottom_nn_model.add(model)

    def set_interactive_layer(self, layer):
        """Add an InteractiveLayer; only allowed on 'common' or 'guest' instances.

        Raises:
            RuntimeError: if this instance's role is 'host' (or any other role).
            AssertionError: if ``layer`` is not an InteractiveLayer.
        """
        if self._role not in ('common', 'guest'):
            raise RuntimeError(
                'You can only set interactive layer in "common" or "guest" hetero nn component')

        if not hasattr(self, "_interactive_layer"):
            self._interactive_layer = Sequential()
        assert isinstance(layer, InteractiveLayer), (
            'You need to add an interactive layer instance, \n'
            'you can access InteractiveLayer by:\n'
            't.nn.InteractiveLayer after fate_torch_hook(t)\n'
            'or from pipeline.component.nn.backend.torch.interactive '
            'import InteractiveLayer')
        self._interactive_layer.add(layer)

    def add_top_model(self, model):
        """Append ``model`` to the top model.

        Raises:
            RuntimeError: if this instance's role is 'host'.
        """
        if self._role == 'host':
            raise RuntimeError('top model is not allow to set on host model')
        if not hasattr(self, "_top_nn_model"):
            self._top_nn_model = Sequential()
        self._top_nn_model.add(model)

    def _set_optimizer(self, opt):
        """Store ``opt.to_dict()`` as the optimizer config."""
        assert hasattr(
            opt, 'to_dict'), 'opt does not have function to_dict(), remember to call fate_torch_hook(t)'
        self.optimizer = opt.to_dict()

    def _set_loss(self, loss):
        """Store ``loss.to_dict()`` as the loss config."""
        assert hasattr(
            loss, 'to_dict'), 'loss does not have function to_dict(), remember to call fate_torch_hook(t)'
        self.loss = loss.to_dict()

    def compile(self, optimizer, loss):
        """Record optimizer/loss and serialize network configs for all instances.

        Compiles this instance's bottom/top configs, then those of every party
        instance, then this instance's interactive layer.
        """
        self._set_optimizer(optimizer)
        self._set_loss(loss)
        self._compile_common_network_config()
        self._compile_role_network_config()
        self._compile_interactive_layer()

    def _has_layers(self, holder_name):
        """Return True when model holder ``holder_name`` exists and is non-empty."""
        return hasattr(self, holder_name) and not getattr(self, holder_name).is_empty()

    def _compile_interactive_layer(self):
        """Serialize the interactive layer (if any) into ``interactive_layer_define``."""
        if self._has_layers("_interactive_layer"):
            self.interactive_layer_define = self._interactive_layer.get_network_config()
            self._component_param["interactive_layer_define"] = self.interactive_layer_define

    def _compile_common_network_config(self):
        """Serialize the bottom and top models (if any) into their ``*_define`` params."""
        if self._has_layers("_bottom_nn_model"):
            self.bottom_nn_define = self._bottom_nn_model.get_network_config()
            self._component_param["bottom_nn_define"] = self.bottom_nn_define
        if self._has_layers("_top_nn_model"):
            self.top_nn_define = self._top_nn_model.get_network_config()
            self._component_param["top_nn_define"] = self.top_nn_define

    def _compile_role_network_config(self):
        """Compile network configs on every (non-None) party instance."""
        for role_entry in self._get_all_party_instance().values():
            for party_inst in role_entry["party"].values():
                # Skip unset slots, matching the guard used by _collect_party_models.
                if party_inst is None:
                    continue
                party_inst._compile_common_network_config()
                party_inst._compile_interactive_layer()

    def _collect_party_models(self, getter_name):
        """Call ``getter_name`` on each non-None party instance and gather results.

        Returns:
            dict mapping party key -> model for every party whose getter returned
            a non-None model, or None if no party produced one.
        """
        # NOTE(review): results are keyed by party key only, so the same key
        # appearing under two roles keeps only the later role's model.
        models = {}
        for role_entry in self._get_all_party_instance().values():
            for party, party_inst in role_entry["party"].items():
                if party_inst is None:
                    continue
                model = getattr(party_inst, getter_name)()
                if model is not None:
                    models[party] = model
        return models if models else None

    def get_bottom_model(self):
        """Return this instance's bottom model, or the per-party bottom models.

        If this instance holds a non-empty bottom model, its ``get_model()`` is
        returned; otherwise a dict of party key -> bottom model (or None).
        """
        if self._has_layers("_bottom_nn_model"):
            return self._bottom_nn_model.get_model()
        return self._collect_party_models("get_bottom_model")

    def get_top_model(self):
        """Return this instance's top model, or the per-party top models.

        If this instance holds a non-empty top model, its ``get_model()`` is
        returned; otherwise a dict of party key -> top model (or None).
        """
        if self._has_layers("_top_nn_model"):
            return self._top_nn_model.get_model()
        return self._collect_party_models("get_top_model")

    def __getstate__(self):
        """Return a copy of the instance dict without the model holders.

        Instances restored from this state lack the holder attributes, which is
        why the add_*/compile/get_* methods check for them before use.
        """
        state = dict(self.__dict__)
        for holder_name in ("_bottom_nn_model", "_interactive_layer", "_top_nn_model"):
            state.pop(holder_name, None)
        return state
|