123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from torch.utils.data import DataLoader
- from federatedml.framework.hetero.procedure import batch_generator
- from federatedml.nn.hetero.base import HeteroNNBase
- from federatedml.nn.hetero.model import HeteroNNHostModel
- from federatedml.param.hetero_nn_param import HeteroNNParam as NNParameter
- from federatedml.protobuf.generated.hetero_nn_model_meta_pb2 import HeteroNNMeta
- from federatedml.protobuf.generated.hetero_nn_model_param_pb2 import HeteroNNParam
- from federatedml.util import consts, LOGGER
- MODELMETA = "HeteroNNHostMeta"
- MODELPARAM = "HeteroNNHostParam"
- class HeteroNNHost(HeteroNNBase):
- def __init__(self):
- super(HeteroNNHost, self).__init__()
- self.batch_generator = batch_generator.Host()
- self.model = None
- self.role = consts.HOST
- self.input_shape = None
- self.default_table_partitions = 4
- def _init_model(self, hetero_nn_param):
- super(HeteroNNHost, self)._init_model(hetero_nn_param)
- def export_model(self):
- if self.need_cv:
- return None
- model = {MODELMETA: self._get_model_meta(),
- MODELPARAM: self._get_model_param()}
- return model
- def load_model(self, model_dict):
- model_dict = list(model_dict["model"].values())[0]
- param = model_dict.get(MODELPARAM)
- meta = model_dict.get(MODELMETA)
- if self.hetero_nn_param is None:
- self.hetero_nn_param = NNParameter()
- self.hetero_nn_param.check()
- self.predict_param = self.hetero_nn_param.predict_param
- self._build_model()
- self._restore_model_meta(meta)
- self._restore_model_param(param)
- def _build_model(self):
- self.model = HeteroNNHostModel(self.hetero_nn_param, self.flowid)
- self.model.set_transfer_variable(self.transfer_variable)
- self.model.set_partition(self.default_table_partitions)
- def predict(self, data_inst):
- ds = self.prepare_dataset(data_inst, data_type='predict')
- batch_size = len(ds) if self.batch_size == -1 else self.batch_size
- for batch_data in DataLoader(ds, batch_size=batch_size):
- # ignore label if the dataset offers label
- if isinstance(batch_data, tuple) and len(batch_data) > 1:
- batch_data = batch_data[0]
- self.model.predict(batch_data)
- def fit(self, data_inst, validate_data=None):
- if hasattr(
- data_inst,
- 'partitions') and data_inst.partitions is not None:
- self.default_table_partitions = data_inst.partitions
- LOGGER.debug(
- 'reset default partitions is {}'.format(
- self.default_table_partitions))
- train_ds = self.prepare_dataset(data_inst, data_type='train')
- if validate_data is not None:
- val_ds = self.prepare_dataset(validate_data, data_type='validate')
- else:
- val_ds = None
- self.callback_list.on_train_begin(train_ds, val_ds)
- if not self.component_properties.is_warm_start:
- self._build_model()
- epoch_offset = 0
- else:
- self.callback_warm_start_init_iter(self.history_iter_epoch)
- epoch_offset = self.history_iter_epoch + 1
- batch_size = len(train_ds) if self.batch_size == - \
- 1 else self.batch_size
- for cur_epoch in range(epoch_offset, epoch_offset + self.epochs):
- self.iter_epoch = cur_epoch
- for batch_idx, batch_data in enumerate(
- DataLoader(train_ds, batch_size=batch_size)):
- self.model.train(batch_data, cur_epoch, batch_idx)
- self.callback_list.on_epoch_end(cur_epoch)
- if self.callback_variables.stop_training:
- LOGGER.debug('early stopping triggered')
- break
- is_converge = self.transfer_variable.is_converge.get(
- idx=0, suffix=(cur_epoch,))
- if is_converge:
- LOGGER.debug(
- "Training process is converged in epoch {}".format(cur_epoch))
- break
- self.callback_list.on_train_end()
- def _get_model_meta(self):
- model_meta = HeteroNNMeta()
- model_meta.batch_size = self.batch_size
- model_meta.hetero_nn_model_meta.CopyFrom(
- self.model.get_hetero_nn_model_meta())
- model_meta.module = 'HeteroNN'
- return model_meta
- def _get_model_param(self):
- model_param = HeteroNNParam()
- model_param.iter_epoch = self.iter_epoch
- model_param.header.extend(self._header)
- model_param.hetero_nn_model_param.CopyFrom(
- self.model.get_hetero_nn_model_param())
- model_param.best_iteration = self.callback_variables.best_iteration
- return model_param
|