- import abc
- import importlib
- import torch as t
- import numpy as np
- from torch.nn import Module
- from typing import List
- from federatedml.util import consts
- from federatedml.util import LOGGER
- from federatedml.model_base import serialize_models
- from federatedml.nn.backend.utils.common import ML_PATH
- from federatedml.feature.instance import Instance
- from federatedml.evaluation.evaluation import Evaluation
- from federatedml.model_base import Metric, MetricMeta
- from federatedml.param import EvaluateParam
- class StdReturnFormat(object):
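- """Standardized return wrapper for prediction results: an id table, a score table, and the class list."""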
- def __init__(self, id_table_list, pred_table, classes):
- self.id = id_table_list
- self.pred_table = pred_table
- self.classes = classes
- def __call__(self):
- return self.id, self.pred_table, self.classes
- class ExporterBase(object):
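- """Base class for model exporters; subclasses implement export_model_dict() to build the model dict that save() caches."""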
- def __init__(self, *args, **kwargs):
- pass
- def export_model_dict(self, model=None, optimizer=None, model_define=None, optimizer_define=None, loss_define=None,
- epoch_idx=-1, converge_status=False, loss_history=None, best_epoch=-1, extra_data=None):
- pass
- class TrainerBase(object):
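- """Base class for homo-NN trainers: holds the model, federation mode, tracker, checkpoint and exporter, and provides save/checkpoint/evaluation/callback utilities for subclasses."""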
- def __init__(self, **kwargs):
- self._fed_mode = True
- self.role = None
- self.party_id = None
- self.party_id_list = None
- self._flowid = None
- self._cache_model = None
- self._model = None
- self._tracker = None
- self._model_checkpoint = None
- self._exporter = None
- self._evaluation_summary = {}
- # running status
- self._set_model_checkpoint_epoch = set()
- # nn config
- self.nn_define, self.opt_define, self.loss_define = {}, {}, {}
- # ret summary
- self._summary = {}
- @staticmethod
- def is_pos_int(val):
- # check the type first so non-numeric values do not raise on comparison
- return isinstance(val, int) and val > 0
- @staticmethod
- def is_float(val):
- return isinstance(val, float)
- @staticmethod
- def is_bool(val):
- return isinstance(val, bool)
- @staticmethod
- def check_trainer_param(
- var_list,
- name_list,
- judge_func,
- warning_str,
- allow_none=True):
- for var, name in zip(var_list, name_list):
- if allow_none and var is None:
- continue
- assert judge_func(var), warning_str.format(name)
- @property
- def model(self):
- if not hasattr(self, '_model'):
- raise AttributeError(
- 'model variable is not initialized, remember to call'
- ' super(your_class, self).__init__()')
- if self._model is None:
- raise AttributeError(
- 'model is not set, use set_model() function to set training model')
- return self._model
- @model.setter
- def model(self, val):
- self._model = val
- @property
- def fed_mode(self):
- if not hasattr(self, '_fed_mode'):
- raise AttributeError(
- 'fed_mode variable is not initialized, remember to call'
- ' super(your_class, self).__init__()')
- return self._fed_mode
- @fed_mode.setter
- def fed_mode(self, val):
- assert isinstance(val, bool), 'fed mode must be a bool'
- self._fed_mode = val
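- # switch the trainer to local debugging mode, in which federation steps are skipped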
- def local_mode(self):
- self.fed_mode = False
- def set_nn_config(self, nn_define, optimizer_define, loss_define):
- self.nn_define = nn_define
- self.opt_define = optimizer_define
- self.loss_define = loss_define
- def set_tracker(self, tracker):
- self._tracker = tracker
- def set_checkpoint(self, chkp):
- self._model_checkpoint = chkp
- def set_party_id_list(self, party_id_list):
- self.party_id_list = party_id_list
- def set_model_exporter(self, exporter):
- assert isinstance(
- exporter, ExporterBase), 'exporter is not an instance of ExporterBase'
- self._exporter = exporter
- def get_cached_model(self):
- return self._cache_model
- @staticmethod
- def task_type_infer(predict_result: t.Tensor, true_label):
- # infer task type and classes(of classification task)
- predict_result = predict_result.cpu()
- true_label = true_label.cpu()
- pred_shape = predict_result.shape
- with t.no_grad():
- # labels containing exactly 0.0 and 1.0 indicate a binary task
- if true_label.max() == 1.0 and true_label.min() == 0.0:
- return consts.BINARY
- if (len(pred_shape) > 1) and (pred_shape[1] > 1):
- # multi-class if every row of the score matrix sums to 1 (softmax output)
- if t.isclose(predict_result.sum(axis=1), t.Tensor([1.0])).all():
- return consts.MULTY
- else:
- return None
- elif (len(pred_shape) == 1) or (pred_shape[1] == 1):
- # a single score per sample is treated as regression
- return consts.REGRESSION
- return None
- def _update_metric_summary(self, metric_dict):
- if len(metric_dict) == 0:
- return
- iter_name = list(metric_dict.keys())[0]
- metric_dict = metric_dict[iter_name]
- if len(self._evaluation_summary) == 0:
- self._evaluation_summary = {namespace: {}
- for namespace in metric_dict}
- for namespace in metric_dict:
- for metric_name in metric_dict[namespace]:
- epoch_metric = metric_dict[namespace][metric_name]
- if namespace not in self._evaluation_summary:
- self._evaluation_summary[namespace] = {}
- if metric_name not in self._evaluation_summary[namespace]:
- self._evaluation_summary[namespace][metric_name] = []
- self._evaluation_summary[namespace][metric_name].append(
- epoch_metric)
- def get_evaluation_summary(self):
- return self._evaluation_summary
- def get_summary(self):
- return self._summary
- """
- User Interfaces
- """
- def set_model(self, model: Module):
- if not isinstance(model, Module):
- raise ValueError('model must be an instance of pytorch nn.Module')
- self.model = model
- def save(
- self,
- model=None,
- epoch_idx=-1,
- optimizer=None,
- converge_status=False,
- loss_history=None,
- best_epoch=-1,
- extra_data=None):
- extra_data = extra_data if extra_data is not None else {}
- assert isinstance(
- epoch_idx, int) and epoch_idx >= 0, 'epoch idx must be an int >= 0'
- if self._exporter:
- model_dict = self._exporter.export_model_dict(model=model,
- optimizer=optimizer,
- model_define=self.nn_define,
- optimizer_define=self.opt_define,
- loss_define=self.loss_define,
- epoch_idx=epoch_idx,
- converge_status=converge_status,
- loss_history=loss_history,
- best_epoch=best_epoch,
- extra_data=extra_data
- )
- self._cache_model = model_dict
- def checkpoint(
- self,
- epoch_idx,
- model=None,
- optimizer=None,
- converge_status=False,
- loss_history=None,
- best_epoch=-1,
- extra_data=None):
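- # subclasses typically call checkpoint() once per epoch; unlike save(), it also registers the exported model with the checkpoint manager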
- assert isinstance(
- epoch_idx, int) and epoch_idx >= 0, 'epoch idx must be an int >= 0'
- if self._model_checkpoint:
- if self._exporter is None:
- raise RuntimeError('exporter is None, cannot save checkpoint')
- if epoch_idx in self._set_model_checkpoint_epoch:
- LOGGER.info(
- 'checkpoint at epoch {} set, skip setting checkpoint'.format(epoch_idx))
- return
- self.save(model=model, epoch_idx=epoch_idx, optimizer=optimizer, converge_status=converge_status,
- loss_history=loss_history, best_epoch=best_epoch, extra_data=extra_data)
- self._model_checkpoint.add_checkpoint(len(self._set_model_checkpoint_epoch),
- to_save_model=serialize_models(self._cache_model)) # step_index, to_save_model
- self._set_model_checkpoint_epoch.add(epoch_idx)
- LOGGER.info('checkpoint at epoch {} saved'.format(epoch_idx))
- def format_predict_result(self, sample_ids: List, predict_result: t.Tensor,
- true_label: t.Tensor, task_type: str = None):
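- # assemble FATE's standard prediction output: an (id, Instance(label)) table, an (id, score) table and the class list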
- predict_result = predict_result.cpu().detach()
- if task_type == 'auto':
- task_type = self.task_type_infer(predict_result, true_label)
- if task_type is None:
- LOGGER.warning(
- 'unable to infer predict result type, predict process will be skipped')
- return None
- classes = None
- if task_type == consts.BINARY:
- classes = [0, 1]
- elif task_type == consts.MULTY:
- classes = list(range(predict_result.shape[1]))
- true_label = true_label.cpu().detach().flatten().tolist()
- if task_type == consts.MULTY:
- predict_result = predict_result.tolist()
- else:
- predict_result = predict_result.flatten().tolist()
- id_table = [(id_, Instance(label=l))
- for id_, l in zip(sample_ids, true_label)]
- score_table = [(id_, pred)
- for id_, pred in zip(sample_ids, predict_result)]
- return StdReturnFormat(id_table, score_table, classes)
- def callback_metric(self, metric_name: str, value: float, metric_type='train', epoch_idx=0):
- assert metric_type in [
- 'train', 'validate'], 'metric_type should be train or validate'
- iter_name = 'iteration_{}'.format(epoch_idx)
- if self._tracker is not None:
- self._tracker.log_metric_data(
- metric_type, iter_name,
- [Metric(metric_name, np.round(value, 6))])
- self._tracker.set_metric_meta(
- metric_type, iter_name,
- MetricMeta(name=metric_name, metric_type='EVALUATION_SUMMARY'))
- def callback_loss(self, loss: float, epoch_idx: int):
- if self._tracker is not None:
- self._tracker.log_metric_data(
- metric_name="loss",
- metric_namespace="train",
- metrics=[Metric(epoch_idx, loss)],
- )
- def summary(self, summary_dict: dict):
- assert isinstance(summary_dict, dict), 'summary must be a dict'
- self._summary = summary_dict
- def evaluation(self, sample_ids: list, pred_scores: t.Tensor, label: t.Tensor, dataset_type='train',
- metric_list=None, epoch_idx=0, task_type=None):
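- # run the built-in Evaluation component on the given scores and labels, push results to the tracker and update the metric summary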
- eval_obj = Evaluation()
- if task_type == 'auto':
- task_type = self.task_type_infer(pred_scores, label)
- if task_type is None:
- LOGGER.debug('cannot infer task type, return')
- return
- assert dataset_type in [
- 'train', 'validate'], 'dataset_type must be in ["train", "validate"]'
- eval_param = EvaluateParam(eval_type=task_type)
- if task_type == consts.BINARY:
- eval_param.metrics = ['auc', 'ks']
- elif task_type == consts.MULTY:
- eval_param.metrics = ['accuracy', 'precision', 'recall']
- eval_param.check_single_value_default_metric()
- eval_obj._init_model(eval_param)
- pred_scores = pred_scores.cpu().detach().numpy()
- label = label.cpu().detach().numpy().flatten()
- if task_type in (consts.REGRESSION, consts.BINARY):
- pred_scores = pred_scores.flatten()
- pred_scores = pred_scores.tolist()
- label = label.tolist()
- assert len(pred_scores) == len(
- label), 'the length of predict score != the length of label, pred {} and label {}'.format(len(pred_scores), len(label))
- eval_data = []
- for id_, s, l in zip(sample_ids, pred_scores, label):
- if task_type == consts.REGRESSION:
- eval_data.append([id_, (l, s, s)])
- elif task_type == consts.MULTY:
- # predicted label is the argmax over the per-class scores
- pred_label = np.argmax(s)
- eval_data.append([id_, (l, pred_label, s)])
- elif task_type == consts.BINARY:
- # threshold the positive-class score at 0.5 so the 0/1 predicted label matches the 0/1 ground truth
- pred_label = int(s > 0.5)
- eval_data.append([id_, (l, pred_label, s)])
- eval_result = eval_obj.evaluate_metrics(dataset_type, eval_data)
- if self._tracker is not None:
- eval_obj.set_tracker(self._tracker)
- # send result to fate-board
- eval_obj.callback_metric_data(
- {'iteration_{}'.format(epoch_idx): [eval_result]})
- self._update_metric_summary(eval_obj.metric_summaries)
- return self._evaluation_summary
- def to_cuda(self, var):
- # recursively move tensors/modules inside nested containers onto the GPU,
- # preserving the container type
- if hasattr(var, 'cuda'):
- return var.cuda()
- elif isinstance(var, (tuple, list)):
- return type(var)(self.to_cuda(i) for i in var)
- elif isinstance(var, dict):
- return {k: self.to_cuda(v) for k, v in var.items()}
- else:
- return var
- @abc.abstractmethod
- def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data=None):
- """
- train_set : a Dataset instance; must be an instance of a subclass of Dataset
- (federatedml.nn.dataset.base), for example TableDataset() (from federatedml.nn.dataset.table)
- validate_set : an optional Dataset instance; if provided, must be an instance of a subclass of Dataset
- (federatedml.nn.dataset.base), for example TableDataset() (from federatedml.nn.dataset.table)
- optimizer : a pytorch optimizer instance, for example t.optim.Adam(), t.optim.SGD()
- loss : a pytorch loss instance, for example nn.BCELoss(), nn.CrossEntropyLoss()
- """
- pass
- @abc.abstractmethod
- def predict(self, dataset):
- pass
- @abc.abstractmethod
- def server_aggregate_procedure(self, extra_data=None):
- pass
- """
- Load Trainer
- """
- def get_trainer_class(trainer_module_name: str):
- if trainer_module_name.endswith('.py'):
- trainer_module_name = trainer_module_name[:-3]
- ds_modules = importlib.import_module(
- '{}.homo.trainer.{}'.format(
- ML_PATH, trainer_module_name))
- # return the first class defined in the module that is a subclass of TrainerBase
- for k, v in ds_modules.__dict__.items():
- if isinstance(v, type) and issubclass(v, TrainerBase) and v is not TrainerBase:
- return v
- raise ValueError(
- 'did not find any subclass of TrainerBase in module {}'.format(trainer_module_name))