123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from fate_arch.computing.non_distributed import LocalData
- from federatedml.model_base import ModelBase
- from federatedml.model_selection import start_cross_validation
- from federatedml.nn.backend.utils.data import load_dataset
- from federatedml.nn.dataset.base import Dataset, ShuffleWrapDataset
- from federatedml.param.hetero_nn_param import HeteroNNParam
- from federatedml.transfer_variable.transfer_class.hetero_nn_transfer_variable import HeteroNNTransferVariable
- from federatedml.util import consts
class HeteroNNBase(ModelBase):
    """Common state and helpers shared by hetero-NN components.

    Holds the hyper-parameters copied from ``HeteroNNParam``, the dataset
    configuration, and utilities for restoring a saved model, preparing
    datasets and setting the schema of prediction output.

    ``_build_bottom_model`` and ``_build_interactive_model`` are no-op hooks
    in this class (presumably overridden by role-specific subclasses).
    """

    def __init__(self):
        super().__init__()

        self.tol = None
        self.early_stop = None
        self.seed = 100
        self.epochs = None
        self.batch_size = None
        self._header = []
        self.predict_param = None
        self.hetero_nn_param = None
        self.batch_generator = None
        self.model = None
        self.partition = None
        self.validation_freqs = None
        self.early_stopping_rounds = None
        self.metrics = []
        self.use_first_metric_only = False
        self.transfer_variable = HeteroNNTransferVariable()
        self.model_param = HeteroNNParam()
        self.mode = consts.HETERO
        self.selector_param = None
        self.floating_point_precision = None
        self.history_iter_epoch = 0
        self.iter_epoch = 0
        self.data_x = []
        self.data_y = []
        self.dataset_cache_dict = {}
        self.label_num = None

        # nn related param
        # Declared here so the attribute exists before _init_model() runs.
        self.interactive_layer_lr = None
        self.top_model_define = None
        self.bottom_model_define = None
        self.interactive_layer_define = None
        self.dataset_shuffle = True
        self.dataset = None
        self.dataset_param = None
        self.dataset_shuffle_seed = 100

    def _init_model(self, hetero_nn_param: HeteroNNParam):
        """Copy the settings carried by ``hetero_nn_param`` onto this instance.

        Args:
            hetero_nn_param: parameter object; it is also stored as
                ``self.hetero_nn_param``.
        """
        self.interactive_layer_lr = hetero_nn_param.interactive_layer_lr
        self.epochs = hetero_nn_param.epochs
        self.batch_size = hetero_nn_param.batch_size
        self.seed = hetero_nn_param.seed
        self.early_stop = hetero_nn_param.early_stop
        self.validation_freqs = hetero_nn_param.validation_freqs
        self.early_stopping_rounds = hetero_nn_param.early_stopping_rounds
        self.metrics = hetero_nn_param.metrics
        self.use_first_metric_only = hetero_nn_param.use_first_metric_only
        self.tol = hetero_nn_param.tol
        self.predict_param = hetero_nn_param.predict_param
        self.hetero_nn_param = hetero_nn_param
        self.selector_param = hetero_nn_param.selector_param
        self.floating_point_precision = hetero_nn_param.floating_point_precision

        # nn configs
        self.bottom_model_define = hetero_nn_param.bottom_nn_define
        self.top_model_define = hetero_nn_param.top_nn_define
        self.interactive_layer_define = hetero_nn_param.interactive_layer_define

        # dataset
        dataset_param = hetero_nn_param.dataset.to_dict()
        self.dataset = dataset_param['dataset_name']
        self.dataset_param = dataset_param['param']

    def reset_flowid(self):
        """Append an ``evaluate`` segment to the current flow id."""
        new_flowid = ".".join([self.flowid, "evaluate"])
        self.set_flowid(new_flowid)

    def recovery_flowid(self):
        """Drop the last dot-separated segment of the current flow id."""
        new_flowid = ".".join(self.flowid.split(".")[:-1])
        self.set_flowid(new_flowid)

    def _build_bottom_model(self):
        """Hook for building the bottom model; does nothing in this class."""
        pass

    def _build_interactive_model(self):
        """Hook for building the interactive model; does nothing in this class."""
        pass

    def _restore_model_meta(self, meta):
        """Restore task type, hyper-parameters and model meta from ``meta``.

        The training hyper-parameters are taken from ``meta`` only when the
        component is not warm-starting; the model meta is always handed to
        ``self.model``.
        """
        # self.hetero_nn_param.interactive_layer_lr = meta.interactive_layer_lr
        self.hetero_nn_param.task_type = meta.task_type
        if not self.component_properties.is_warm_start:
            self.batch_size = meta.batch_size
            self.epochs = meta.epochs
            self.tol = meta.tol
            self.early_stop = meta.early_stop

        self.model.set_hetero_nn_model_meta(meta.hetero_nn_model_meta)

    def _restore_model_param(self, param):
        """Restore model parameters, header and epoch counters from ``param``."""
        self.model.set_hetero_nn_model_param(param.hetero_nn_model_param)
        self._header = list(param.header)
        self.history_iter_epoch = param.iter_epoch
        self.iter_epoch = param.iter_epoch

    def set_partition(self, data_inst):
        """Record ``data_inst.partitions`` and forward it to ``self.model``."""
        self.partition = data_inst.partitions
        self.model.set_partition(self.partition)

    def cross_validation(self, data_instances):
        """Run cross validation for this component on ``data_instances``."""
        return start_cross_validation.run(self, data_instances)

    def prepare_dataset(self, data, data_type='train', check_label=False):
        """Turn ``data`` into a dataset object ready for hetero-NN use.

        Args:
            data: a ``LocalData`` (its ``path`` is used), an existing
                ``Dataset`` / ``ShuffleWrapDataset`` (used as is), or any
                other input accepted by ``load_dataset``.
            data_type: tag added to the transfer suffix of the shuffled
                index map, e.g. ``'train'``.
            check_label: when True, query the dataset classes and, if
                ``self.label_num`` is still None, derive it from the task
                type.

        Returns:
            The prepared dataset, wrapped in ``ShuffleWrapDataset`` when
            ``self.dataset_shuffle`` is set.

        Raises:
            ValueError: the dataset has no sample ids, or on the host side
                the received index map and the local dataset differ in
                length.
            NotImplementedError: ``check_label`` is set and the dataset does
                not implement ``get_classes()``.
        """
        # train input & validate input are DTables or path str
        if isinstance(data, LocalData):
            data = data.path

        if isinstance(data, (Dataset, ShuffleWrapDataset)):
            ds = data
        else:
            ds = load_dataset(
                self.dataset,
                data,
                self.dataset_param,
                self.dataset_cache_dict)

        if not ds.has_sample_ids():
            raise ValueError(
                'Dataset has no sample id, this is not allowed in hetero-nn, please make sure'
                ' that you implement get_sample_ids()')

        if self.dataset_shuffle:
            ds = ShuffleWrapDataset(
                ds, shuffle_seed=self.dataset_shuffle_seed)
            # The guest publishes its shuffled index map and the host adopts
            # it, so both sides iterate samples in the same order.
            if self.role == consts.GUEST:
                self.transfer_variable.dataset_info.remote(
                    ds.idx_map, idx=-1, suffix=('idx_map', data_type))
            if self.role == consts.HOST:
                idx_map = self.transfer_variable.dataset_info.get(
                    idx=0, suffix=('idx_map', data_type))
                # Raised explicitly rather than asserted: asserts are removed
                # when Python runs with -O, which would skip this check.
                if len(idx_map) != len(ds):
                    raise ValueError(
                        'host dataset len != guest dataset len, please check your dataset, '
                        'guest len {}, host len {}'.format(len(idx_map), len(ds)))
                ds.set_shuffled_idx(idx_map)

        if check_label:
            try:
                all_classes = ds.get_classes()
            except NotImplementedError as e:
                raise NotImplementedError(
                    'get_classes() is not implemented, please implement this function'
                    ' when you are using hetero-nn. Let it return classes in a list.'
                    ' Please see built-in dataset(table.py for example) for reference') from e

            from federatedml.util import LOGGER
            LOGGER.debug('all classes is {}'.format(all_classes))
            if self.label_num is None:
                if self.task_type == consts.CLASSIFICATION:
                    self.label_num = len(all_classes)
                elif self.task_type == consts.REGRESSION:
                    self.label_num = 1

        return ds

    # override function
    @staticmethod
    def set_predict_data_schema(predict_datas, schemas):
        """Attach the prediction-result schema to the prediction output.

        Args:
            predict_datas: the prediction output, or a list of which only
                the first element is used; None is returned unchanged.
            schemas: the input schema, or a list of which only the first
                element is used; its ``match_id_name`` entry, when not None,
                is copied to the output schema.

        Returns:
            The (first) prediction output with its ``schema`` set, or None.
        """
        if predict_datas is None:
            return predict_datas

        if isinstance(predict_datas, list):
            predict_data = predict_datas[0]
            schema = schemas[0]
        else:
            predict_data = predict_datas
            schema = schemas

        if predict_data is not None:
            predict_data.schema = {
                "header": [
                    "label",
                    "predict_result",
                    "predict_score",
                    "predict_detail",
                    "type",
                ],
                "sid": 'id',
                "content_type": "predict_result"
            }
            match_id_name = schema.get("match_id_name")
            if match_id_name is not None:
                predict_data.schema["match_id_name"] = match_id_name

        return predict_data
|