#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import torch
from torch.utils.data import DataLoader
from fate_arch.computing._util import is_table
from fate_arch.session import computing_session as session
from federatedml.feature.instance import Instance
from federatedml.framework.hetero.procedure import batch_generator
from federatedml.model_base import Metric
from federatedml.model_base import MetricMeta
from federatedml.nn.hetero.base import HeteroNNBase
from federatedml.nn.hetero.model import HeteroNNGuestModel
from federatedml.optim.convergence import converge_func_factory
from federatedml.param.evaluation_param import EvaluateParam
from federatedml.param.hetero_nn_param import HeteroNNParam as NNParameter
from federatedml.protobuf.generated.hetero_nn_model_meta_pb2 import HeteroNNMeta
from federatedml.protobuf.generated.hetero_nn_model_param_pb2 import HeteroNNParam
from federatedml.util import consts, LOGGER
from federatedml.util.io_check import assert_io_num_rows_equal
from federatedml.nn.dataset.table import TableDataset
from federatedml.statistic.data_overview import check_with_inst_id
from federatedml.nn.backend.utils.data import add_match_id
# Keys under which meta / param protobufs are packed in the exported model dict
# (read back in load_model via model_dict.get(...)).
MODELMETA = "HeteroNNGuestMeta"
MODELPARAM = "HeteroNNGuestParam"
class HeteroNNGuest(HeteroNNBase):
    """Guest-side party of the hetero (vertical) federated neural network.

    The guest holds the labels: it drives training batch by batch through
    ``HeteroNNGuestModel``, reports the per-epoch loss as a metric, checks
    convergence, and notifies the hosts whether to stop.
    """

    def __init__(self):
        super(HeteroNNGuest, self).__init__()
        self.task_type = None        # filled from the component param in _init_model
        self.converge_func = None    # built from early_stop / tol in _init_model
        self.batch_generator = batch_generator.Guest()
        self.data_keys = []
        self.label_dict = {}
        self.model = None            # HeteroNNGuestModel, created in _build_model
        self.role = consts.GUEST
        self.history_loss = []       # averaged train loss, one entry per epoch
        self.input_shape = None
        # summary buffer reported to the framework at the end of fit()
        self._summary_buf = {"history_loss": [],
                             "is_converged": False,
                             "best_iteration": -1}
        self.dataset_cache_dict = {}
        # overridden by the input table's own partition count in fit()
        self.default_table_partitions = 4
- def _init_model(self, hetero_nn_param):
- super(HeteroNNGuest, self)._init_model(hetero_nn_param)
- self.task_type = hetero_nn_param.task_type
- self.converge_func = converge_func_factory(self.early_stop, self.tol)
- def _build_model(self):
- self.model = HeteroNNGuestModel(
- self.hetero_nn_param, self.component_properties, self.flowid)
- self.model.set_transfer_variable(self.transfer_variable)
- self.model.set_partition(self.default_table_partitions)
- def _set_loss_callback_info(self):
- self.callback_meta("loss",
- "train",
- MetricMeta(name="train",
- metric_type="LOSS",
- extra_metas={"unit_name": "iters"}))
- @staticmethod
- def _disable_sample_weight(dataset):
- # currently not support sample weight
- if isinstance(dataset, TableDataset):
- dataset.with_sample_weight = False
    def fit(self, data_inst, validate_data=None):
        """Train the guest side of the hetero NN.

        Builds (or warm-starts) the guest model, iterates over the training
        dataset, accumulates the per-epoch loss, reports it through the
        callback system, and remotes the convergence decision to the hosts
        after every epoch.

        :param data_inst: training input (table or dataset accepted by
            prepare_dataset)
        :param validate_data: optional validation input, forwarded to the
            callback list
        """
        # Inherit the partition count of the input table so that the tables
        # produced later (e.g. in predict) match the input's parallelism.
        if hasattr(
                data_inst,
                'partitions') and data_inst.partitions is not None:
            self.default_table_partitions = data_inst.partitions
            LOGGER.debug(
                'reset default partitions is {}'.format(
                    self.default_table_partitions))

        train_ds = self.prepare_dataset(
            data_inst, data_type='train', check_label=True)
        train_ds.train()  # set dataset to train mode
        self._disable_sample_weight(train_ds)

        if validate_data is not None:
            val_ds = self.prepare_dataset(validate_data, data_type='validate')
            val_ds.train()  # set dataset to train mode
            self._disable_sample_weight(val_ds)
        else:
            val_ds = None

        self.callback_list.on_train_begin(train_ds, val_ds)

        # collect data from table to form data loader
        if not self.component_properties.is_warm_start:
            self._build_model()
            epoch_offset = 0
        else:
            # Warm start: resume epoch numbering after the last trained epoch.
            self.callback_warm_start_init_iter(self.history_iter_epoch)
            epoch_offset = self.history_iter_epoch + 1

        # set label number
        self.model.set_label_num(self.label_num)

        # NOTE(review): with an empty dataset the loop below still divides by
        # len(train_ds); presumably set_empty() makes the guest skip real
        # training — confirm against HeteroNNGuestModel.
        if len(train_ds) == 0:
            self.model.set_empty()

        self._set_loss_callback_info()

        # batch_size == -1 means full-batch training.
        batch_size = len(train_ds) if self.batch_size == -1 else self.batch_size
        # NOTE(review): num_workers is hard-coded to 4 here.
        data_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=4)

        for cur_epoch in range(epoch_offset, self.epochs + epoch_offset):
            self.iter_epoch = cur_epoch
            LOGGER.debug("cur epoch is {}".format(cur_epoch))
            self.callback_list.on_epoch_begin(cur_epoch)
            epoch_loss = 0
            acc_sample_num = 0

            for batch_idx, (batch_data, batch_label) in enumerate(data_loader):
                batch_loss = self.model.train(
                    batch_data, batch_label, cur_epoch, batch_idx)
                # Weight the (assumed per-sample mean) batch loss by the true
                # size of the batch; the final batch may be short.
                if acc_sample_num + batch_size > len(train_ds):
                    batch_len = len(train_ds) - acc_sample_num
                else:
                    batch_len = batch_size
                acc_sample_num += batch_size
                epoch_loss += batch_loss * batch_len

            # Mean loss over the whole dataset for this epoch.
            epoch_loss = epoch_loss / len(train_ds)
            LOGGER.debug("epoch {} loss is {}".format(cur_epoch, epoch_loss))
            self.callback_metric("loss",
                                 "train",
                                 [Metric(cur_epoch, epoch_loss)])
            self.history_loss.append(epoch_loss)
            self.callback_list.on_epoch_end(cur_epoch)
            # Early stopping requested by a callback (e.g. validation-based):
            # break BEFORE remoting a convergence flag for this epoch.
            if self.callback_variables.stop_training:
                LOGGER.debug('early stopping triggered')
                break

            if self.hetero_nn_param.selector_param.method:
                # when use selective bp, loss converge will be disabled
                is_converge = False
            else:
                is_converge = self.converge_func.is_converge(epoch_loss)
            self._summary_buf["is_converged"] = is_converge
            # Tell every host party whether training has converged this epoch.
            self.transfer_variable.is_converge.remote(is_converge,
                                                      role=consts.HOST,
                                                      idx=-1,
                                                      suffix=(cur_epoch,))
            if is_converge:
                LOGGER.debug(
                    "Training process is converged in epoch {}".format(cur_epoch))
                break

        self.callback_list.on_train_end()
        self.set_summary(self._get_model_summary())
    @assert_io_num_rows_equal
    def predict(self, data_inst):
        """Run inference and return a prediction table.

        Scores the dataset with the trained guest model, re-attaches labels
        (and match ids when the input table carries them), and formats the
        scores according to the task type (regression / binary / multi-class).

        :param data_inst: input data (table or dataset)
        :return: prediction output table produced by predict_score_to_output
        """
        with_match_id = False
        if is_table(data_inst):
            with_match_id = check_with_inst_id(data_inst)

        ds = self.prepare_dataset(data_inst, data_type='predict')
        ds.eval()  # set dataset to eval mode
        self._disable_sample_weight(ds)
        keys = ds.get_sample_ids()
        # batch_size == -1 means score everything in a single batch.
        batch_size = len(ds) if self.batch_size == -1 else self.batch_size
        dl = DataLoader(ds, batch_size=batch_size)
        preds = []
        labels = []
        for batch_data, batch_label in dl:
            batch_pred = self.model.predict(batch_data)
            preds.append(batch_pred)
            labels.append(batch_label)

        # Predictions come back as numpy arrays, labels as torch tensors.
        preds = np.concatenate(preds, axis=0)
        labels = torch.concat(labels, dim=0).cpu().numpy().flatten().tolist()
        # Rebuild an id -> Instance(label) table aligned with the sample ids.
        id_table = [(id_, Instance(label=l)) for id_, l in zip(keys, labels)]
        if with_match_id:
            add_match_id(id_table, ds.ds)  # ds is wrap shuffle dataset here
        data_inst = session.parallelize(
            id_table,
            partition=self.default_table_partitions,
            include_key=True)

        if self.task_type == consts.REGRESSION:
            # One scalar score per sample.
            preds = preds.flatten().tolist()
            preds = [float(pred) for pred in preds]
            predict_tb = session.parallelize(zip(keys, preds), include_key=True,
                                             partition=self.default_table_partitions)
            result = self.predict_score_to_output(data_inst, predict_tb)
        else:
            if self.label_num > 2:
                # Multi-class: keep the full score vector per sample.
                preds = preds.tolist()
                preds = [list(map(float, pred)) for pred in preds]
                predict_tb = session.parallelize(zip(keys, preds), include_key=True,
                                                 partition=self.default_table_partitions)
                result = self.predict_score_to_output(
                    data_inst, predict_tb, classes=list(range(self.label_num)))
            else:
                # Binary: single positive-class score, thresholded downstream.
                preds = preds.flatten().tolist()
                preds = [float(pred) for pred in preds]
                predict_tb = session.parallelize(zip(keys, preds), include_key=True,
                                                 partition=self.default_table_partitions)
                threshold = self.predict_param.threshold
                result = self.predict_score_to_output(
                    data_inst, predict_tb, classes=[
                        0, 1], threshold=threshold)

        return result
- def export_model(self):
- if self.need_cv:
- return None
- model = {MODELMETA: self._get_model_meta(),
- MODELPARAM: self._get_model_param()}
- return model
- def load_model(self, model_dict):
- model_dict = list(model_dict["model"].values())[0]
- param = model_dict.get(MODELPARAM)
- meta = model_dict.get(MODELMETA)
- if self.hetero_nn_param is None:
- self.hetero_nn_param = NNParameter()
- self.hetero_nn_param.check()
- self.predict_param = self.hetero_nn_param.predict_param
- self._build_model()
- self._restore_model_meta(meta)
- self._restore_model_param(param)
- def _get_model_summary(self):
- self._summary_buf["history_loss"] = self.history_loss
- if self.callback_variables.validation_summary:
- self._summary_buf["validation_metrics"] = self.callback_variables.validation_summary
- """
- if self.validation_strategy:
- validation_summary = self.validation_strategy.summary()
- if validation_summary:
- self._summary_buf["validation_metrics"] = validation_summary
- """
- return self._summary_buf
- def _get_model_meta(self):
- model_meta = HeteroNNMeta()
- model_meta.task_type = self.task_type
- model_meta.module = 'HeteroNN'
- model_meta.batch_size = self.batch_size
- model_meta.epochs = self.epochs
- model_meta.early_stop = self.early_stop
- model_meta.tol = self.tol
- model_meta.hetero_nn_model_meta.CopyFrom(
- self.model.get_hetero_nn_model_meta())
- return model_meta
- def _get_model_param(self):
- model_param = HeteroNNParam()
- model_param.iter_epoch = self.iter_epoch
- model_param.hetero_nn_model_param.CopyFrom(
- self.model.get_hetero_nn_model_param())
- model_param.num_label = self.label_num
- model_param.best_iteration = self.callback_variables.best_iteration
- model_param.header.extend(self._header)
- for loss in self.history_loss:
- model_param.history_loss.append(loss)
- return model_param
- def get_metrics_param(self):
- if self.task_type == consts.CLASSIFICATION:
- if self.label_num == 2:
- return EvaluateParam(eval_type="binary",
- pos_label=1, metrics=self.metrics)
- else:
- return EvaluateParam(eval_type="multi", metrics=self.metrics)
- else:
- return EvaluateParam(eval_type="regression", metrics=self.metrics)
    def _restore_model_param(self, param):
        # Restore the shared state first, then the guest-only label count.
        super(HeteroNNGuest, self)._restore_model_param(param)
        self.label_num = param.num_label