#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
import typing

import numpy as np
from google.protobuf import json_format

from fate_arch.computing import is_table
from federatedml.callbacks.callback_list import CallbackList
from federatedml.feature.instance import Instance
from federatedml.param.evaluation_param import EvaluateParam
from federatedml.protobuf import deserialize_models
from federatedml.statistic.data_overview import header_alignment, predict_detail_dict_to_str
from federatedml.util import LOGGER, abnormal_detection
from federatedml.util.anonymous_generator_util import Anonymous
from federatedml.util.component_properties import ComponentProperties, RunningFuncs
from federatedml.util.io_check import assert_match_id_consistent


def serialize_models(models):
    serialized_models: typing.Dict[str, typing.Tuple[str, bytes, dict]] = {}
    for model_name, buffer_object in models.items():
        serialized_string = buffer_object.SerializeToString()
        pb_name = type(buffer_object).__name__
        json_format_dict = json_format.MessageToDict(
            buffer_object, including_default_value_fields=True
        )
        serialized_models[model_name] = (
            pb_name,
            serialized_string,
            json_format_dict,
        )
    return serialized_models
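
# A minimal usage sketch for serialize_models (illustrative only: `DataIOMeta`
# stands in for any generated protobuf message class):
#
#   meta_pb = DataIOMeta()
#   serialized = serialize_models({"DataIOMeta": meta_pb})
#   pb_name, raw_bytes, as_dict = serialized["DataIOMeta"]
#   # pb_name == "DataIOMeta"; raw_bytes is the wire-format payload;
#   # as_dict is the json_format dict, default-valued fields included.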


class ComponentOutput:
    def __init__(self, data, models, cache: typing.List[tuple]) -> None:
        self._data = data
        if not isinstance(self._data, list):
            self._data = [data]

        self._models = models
        if self._models is None:
            self._models = {}

        self._cache = cache
        if not isinstance(self._cache, list):
            self._cache = [cache]

    @property
    def data(self) -> list:
        return self._data

    @property
    def model(self):
        return serialize_models(self._models)

    @property
    def cache(self):
        return self._cache


class MetricType:
    LOSS = "LOSS"


class Metric:
    def __init__(self, key, value: float, timestamp: float = None):
        self.key = key
        self.value = value
        self.timestamp = timestamp

    def to_dict(self):
        return dict(key=self.key, value=self.value, timestamp=self.timestamp)


class MetricMeta:
    def __init__(self, name: str, metric_type: MetricType, extra_metas: dict = None):
        self.name = name
        self.metric_type = metric_type
        self.metas = {}
        self.extra_metas = extra_metas

    def update_metas(self, metas: dict):
        self.metas.update(metas)

    def to_dict(self):
        return dict(
            name=self.name,
            metric_type=self.metric_type,
            metas=self.metas,
            extra_metas=self.extra_metas,
        )
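
# Illustrative pairing of the two classes above: a training-loss point could be
# reported as Metric(key=epoch_idx, value=0.35), described by
# MetricMeta(name="loss", metric_type=MetricType.LOSS,
#            extra_metas={"unit_name": "iters"}).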


class CallbacksVariable(object):
    def __init__(self):
        self.stop_training = False
        self.best_iteration = -1
        self.validation_summary = None


class WarpedTrackerClient:
    def __init__(self, tracker) -> None:
        self._tracker = tracker

    def log_metric_data(
        self, metric_namespace: str, metric_name: str, metrics: typing.List[Metric]
    ):
        return self._tracker.log_metric_data(
            metric_namespace=metric_namespace,
            metric_name=metric_name,
            metrics=[metric.to_dict() for metric in metrics],
        )

    def set_metric_meta(
        self, metric_namespace: str, metric_name: str, metric_meta: MetricMeta
    ):
        return self._tracker.set_metric_meta(
            metric_namespace=metric_namespace,
            metric_name=metric_name,
            metric_meta=metric_meta.to_dict(),
        )

    def log_component_summary(self, summary_data: dict):
        return self._tracker.log_component_summary(summary_data=summary_data)
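
# Hypothetical call sequence against the wrapper above (the namespace and
# metric names are illustrative):
#
#   tracker = WarpedTrackerClient(raw_tracker)
#   tracker.set_metric_meta("train", "loss",
#                           MetricMeta(name="loss", metric_type=MetricType.LOSS))
#   tracker.log_metric_data("train", "loss", [Metric(0, 0.69), Metric(1, 0.52)])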


class ModelBase(object):
    component_name = None

    @classmethod
    def set_component_name(cls, name):
        cls.component_name = name

    @classmethod
    def get_component_name(cls):
        return cls.component_name

    def __init__(self):
        self.model_output = None
        self.mode = None
        self.role = None
        self.data_output = None
        self.cache_output = None
        self.model_param = None
        self.transfer_variable = None
        self.flowid = ""
        self.task_version_id = ""
        self.need_one_vs_rest = False
        self.callback_one_vs_rest = False
        self.checkpoint_manager = None
        self.cv_fold = 0
        self.validation_freqs = None
        self.component_properties = ComponentProperties()
        self._summary = dict()
        self._align_cache = dict()
        self._tracker = None
        self.step_name = "step_name"
        self.callback_list: CallbackList
        self.callback_variables = CallbacksVariable()
        self.anonymous_generator = None

    @property
    def tracker(self) -> WarpedTrackerClient:
        if self._tracker is None:
            raise RuntimeError("tracker used before being set")
        return self._tracker

    @tracker.setter
    def tracker(self, value):
        self._tracker = WarpedTrackerClient(value)

    @property
    def stop_training(self):
        return self.callback_variables.stop_training

    @property
    def need_cv(self):
        return self.component_properties.need_cv

    @property
    def need_run(self):
        return self.component_properties.need_run

    @need_run.setter
    def need_run(self, value: bool):
        self.component_properties.need_run = value

    def _init_model(self, model):
        pass

    def load_model(self, model_dict):
        pass
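
    # A minimal subclass sketch (illustrative, not a real component): concrete
    # components override _init_model/fit/predict and let ModelBase drive the
    # lifecycle via run(). `MyParam` is a hypothetical Param subclass.
    #
    #   class MyComponent(ModelBase):
    #       def __init__(self):
    #           super().__init__()
    #           self.model_param = MyParam()
    #
    #       def _init_model(self, params):
    #           self.max_iter = params.max_iter
    #
    #       def fit(self, train_data, validate_data=None):
    #           ...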

    def _parse_need_run(self, model_dict, model_meta_name):
        meta_obj = list(model_dict.get("model").values())[0].get(model_meta_name)
        need_run = meta_obj.need_run
        # self.need_run = need_run
        self.component_properties.need_run = need_run

    def run(self, cpn_input, retry: bool = True):
        self.task_version_id = cpn_input.task_version_id
        self.tracker = cpn_input.tracker
        self.checkpoint_manager = cpn_input.checkpoint_manager
        deserialize_models(cpn_input.models)

        # retry from the latest checkpoint when one exists
        if (
            retry
            and hasattr(self, '_retry')
            and callable(self._retry)
            and self.checkpoint_manager is not None
            and self.checkpoint_manager.latest_checkpoint is not None
        ):
            self._retry(cpn_input=cpn_input)
        # normal run
        else:
            self._run(cpn_input=cpn_input)

        return ComponentOutput(self.save_data(), self._export(), self.save_cache())

    def _export(self):
        # export model
        try:
            model = self._export_model()
            meta = self._export_meta()
            export_dict = {"Meta": meta, "Param": model}
        except NotImplementedError:
            export_dict = self.export_model()

        # export nothing, return
        if export_dict is None:
            return export_dict

        try:
            meta_name = [k for k in export_dict if k.endswith("Meta")][0]
        except BaseException:
            raise KeyError("Meta not found in export model")

        try:
            param_name = [k for k in export_dict if k.endswith("Param")][0]
        except BaseException:
            raise KeyError("Param not found in export model")

        meta = export_dict[meta_name]

        # set component name
        if hasattr(meta, "component"):
            meta.component = self.get_component_name()
        else:
            import warnings

            warnings.warn(f"{meta} should add `component` field")
        return export_dict
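
    # For illustration, a subclass's export typically returns a dict such as
    # {"HeteroLRMeta": meta_pb, "HeteroLRParam": param_pb} (names illustrative);
    # the endswith checks above rely on exactly this "Meta"/"Param" key suffix
    # convention.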

    def _export_meta(self):
        raise NotImplementedError("_export_meta not implemented")

    def _export_model(self):
        raise NotImplementedError("_export_model not implemented")

    def _run(self, cpn_input) -> None:
        # parameters
        self.model_param.update(cpn_input.parameters)
        self.model_param.check()
        self.component_properties.parse_component_param(
            cpn_input.roles, self.model_param
        )
        self.role = self.component_properties.role
        self.component_properties.parse_dsl_args(cpn_input.datasets, cpn_input.models)
        self.component_properties.parse_caches(cpn_input.caches)

        self.anonymous_generator = Anonymous(role=self.role, party_id=self.component_properties.local_partyid)

        # init component, implemented by subclasses
        self._init_model(self.model_param)

        self.callback_list = CallbackList(self.role, self.mode, self)
        if hasattr(self.model_param, "callback_param"):
            callback_param = getattr(self.model_param, "callback_param")
            self.callback_list.init_callback_list(callback_param)

        running_funcs = self.component_properties.extract_running_rules(
            datasets=cpn_input.datasets, models=cpn_input.models, cpn=self
        )
        LOGGER.debug(f"running_funcs: {running_funcs.todo_func_list}")
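
        # Each running_funcs entry is a (func, params, save_result,
        # use_previews) tuple: save_result keeps the step's return value for
        # later steps, and use_previews feeds the previously saved results
        # back in as arguments.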
        saved_result = []
        for func, params, save_result, use_previews in running_funcs:
            # for func, params in zip(todo_func_list, todo_func_params):
            if use_previews:
                if params:
                    real_param = [saved_result, params]
                else:
                    real_param = saved_result
                LOGGER.debug("func: {}".format(func))
                this_data_output = func(*real_param)
                saved_result = []
            else:
                this_data_output = func(*params)

            if save_result:
                saved_result.append(this_data_output)

        if len(saved_result) == 1:
            self.data_output = saved_result[0]
            # LOGGER.debug("One data: {}".format(self.data_output.first()[1].features))

        LOGGER.debug(
            "saved_result is : {}, data_output: {}".format(
                saved_result, self.data_output
            )
        )
        # self.check_consistency()
        self.save_summary()

    def _retry(self, cpn_input) -> None:
        self.model_param.update(cpn_input.parameters)
        self.model_param.check()
        self.component_properties.parse_component_param(
            cpn_input.roles, self.model_param
        )
        self.role = self.component_properties.role
        self.component_properties.parse_dsl_args(cpn_input.datasets, cpn_input.models)
        self.component_properties.parse_caches(cpn_input.caches)

        # init component, implemented by subclasses
        self._init_model(self.model_param)

        self.callback_list = CallbackList(self.role, self.mode, self)
        if hasattr(self.model_param, "callback_param"):
            callback_param = getattr(self.model_param, "callback_param")
            self.callback_list.init_callback_list(callback_param)

        (
            train_data,
            validate_data,
            test_data,
            data,
        ) = self.component_properties.extract_input_data(
            datasets=cpn_input.datasets, model=self
        )

        running_funcs = RunningFuncs()
        latest_checkpoint = self.get_latest_checkpoint()
        running_funcs.add_func(self.load_model, [latest_checkpoint])
        running_funcs = self.component_properties.warm_start_process(
            running_funcs, self, train_data, validate_data
        )
        LOGGER.debug(f"running_funcs: {running_funcs.todo_func_list}")

        self._execute_running_funcs(running_funcs)

    def _execute_running_funcs(self, running_funcs):
        saved_result = []
        for func, params, save_result, use_previews in running_funcs:
            # for func, params in zip(todo_func_list, todo_func_params):
            if use_previews:
                if params:
                    real_param = [saved_result, params]
                else:
                    real_param = saved_result
                LOGGER.debug("func: {}".format(func))
                detected_func = assert_match_id_consistent(func)
                this_data_output = detected_func(*real_param)
                saved_result = []
            else:
                detected_func = assert_match_id_consistent(func)
                this_data_output = detected_func(*params)

            if save_result:
                saved_result.append(this_data_output)

        if len(saved_result) == 1:
            self.data_output = saved_result[0]

        LOGGER.debug(
            "saved_result is : {}, data_output: {}".format(
                saved_result, self.data_output
            )
        )
        self.save_summary()

    def export_serialized_models(self):
        return serialize_models(self.export_model())

    def get_metrics_param(self):
        return EvaluateParam(eval_type="binary", pos_label=1)

    def check_consistency(self):
        if not is_table(self.data_output):
            return
        if (
            self.component_properties.input_data_count
            + self.component_properties.input_eval_data_count
            != self.data_output.count()
            and self.component_properties.input_data_count
            != self.component_properties.input_eval_data_count
        ):
            raise ValueError("Input data count does not match with output data count")

    def predict(self, data_inst):
        pass

    def fit(self, *args):
        pass

    def transform(self, data_inst):
        pass

    def cross_validation(self, data_inst):
        pass

    def stepwise(self, data_inst):
        pass

    def one_vs_rest_fit(self, train_data=None):
        pass

    def one_vs_rest_predict(self, train_data):
        pass

    def init_validation_strategy(self, train_data=None, validate_data=None):
        pass

    def save_data(self):
        return self.data_output

    def export_model(self):
        return self.model_output

    def save_cache(self):
        return self.cache_output

    def set_flowid(self, flowid):
        # self.flowid = '.'.join([self.task_version_id, str(flowid)])
        self.flowid = flowid
        self.set_transfer_variable()

    def set_transfer_variable(self):
        if self.transfer_variable is not None:
            LOGGER.debug(
                "set flowid to transfer_variable, flowid: {}".format(self.flowid)
            )
            self.transfer_variable.set_flowid(self.flowid)

    def set_task_version_id(self, task_version_id):
        """task_version_id: jobid + component_name, reserved variable"""
        self.task_version_id = task_version_id

    def get_metric_name(self, name_prefix):
        if not self.need_cv:
            return name_prefix
        return "_".join(map(str, [name_prefix, self.flowid]))

    def set_tracker(self, tracker):
        self._tracker = tracker

    def set_checkpoint_manager(self, checkpoint_manager):
        checkpoint_manager.load_checkpoints_from_disk()
        self.checkpoint_manager = checkpoint_manager

    @staticmethod
    def set_predict_data_schema(predict_datas, schemas):
        if predict_datas is None:
            return predict_datas

        if isinstance(predict_datas, list):
            predict_data = predict_datas[0]
            schema = schemas[0]
        else:
            predict_data = predict_datas
            schema = schemas

        if predict_data is not None:
            predict_data.schema = {
                "header": [
                    "label",
                    "predict_result",
                    "predict_score",
                    "predict_detail",
                    "type",
                ],
                "sid": schema.get("sid"),
                "content_type": "predict_result"
            }
            if schema.get("match_id_name") is not None:
                predict_data.schema["match_id_name"] = schema.get("match_id_name")
        return predict_data
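
    # Illustrative outcome: given schema {"sid": "id", "match_id_name": "phone"},
    # the predict output schema becomes
    #   {"header": ["label", "predict_result", "predict_score",
    #               "predict_detail", "type"],
    #    "sid": "id", "content_type": "predict_result", "match_id_name": "phone"}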

    @staticmethod
    def predict_score_to_output(
        data_instances, predict_score, classes=None, threshold=0.5
    ):
        """
        Get predict result output

        Parameters
        ----------
        data_instances: table, data used for prediction
        predict_score: table, probability scores
        classes: list or None, all classes/label names
        threshold: float, predict threshold, used for binary label

        Returns
        -------
        Table, predict result
        """
        # regression
        if classes is None:
            predict_result = data_instances.join(
                predict_score, lambda d, pred: [d.label,
                                                pred,
                                                pred,
                                                predict_detail_dict_to_str({"label": pred})]
            )
        # binary
        elif isinstance(classes, list) and len(classes) == 2:
            class_neg, class_pos = classes[0], classes[1]
            pred_label = predict_score.mapValues(
                lambda x: class_pos if x > threshold else class_neg
            )
            predict_result = data_instances.mapValues(lambda x: x.label)
            predict_result = predict_result.join(predict_score, lambda x, y: (x, y))
            class_neg_name, class_pos_name = str(class_neg), str(class_pos)
            predict_result = predict_result.join(
                pred_label,
                lambda x, y: [
                    x[0],
                    y,
                    x[1],
                    predict_detail_dict_to_str({class_neg_name: (1 - x[1]), class_pos_name: x[1]})
                ],
            )
        # multi-label: input = array of predicted scores over all labels
        elif isinstance(classes, list) and len(classes) > 2:
            # pred_label = predict_score.mapValues(lambda x: classes[x.index(max(x))])
            classes = [str(val) for val in classes]
            predict_result = data_instances.mapValues(lambda x: x.label)
            predict_result = predict_result.join(
                predict_score,
                lambda x, y: [
                    x,
                    int(classes[np.argmax(y)]),
                    float(np.max(y)),
                    predict_detail_dict_to_str(dict(zip(classes, list(y))))
                ],
            )
        else:
            raise ValueError(
                f"Model's classes type is {type(classes)}; classes must be None or a list with at least two elements."
            )

        def _transfer(instance, pred_res):
            return Instance(features=pred_res, inst_id=instance.inst_id)

        predict_result = data_instances.join(predict_result, _transfer)
        return predict_result
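
    # In the binary branch above, each output Instance's features look like
    #   [1, 1, 0.83, '{"0": 0.17, "1": 0.83}']
    # i.e. [true label, predicted label, raw score, per-class detail JSON]
    # (values illustrative).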

    def callback_meta(self, metric_name, metric_namespace, metric_meta: MetricMeta):
        if self.need_cv:
            metric_name = ".".join([metric_name, str(self.cv_fold)])
            flow_id_list = self.flowid.split(".")
            LOGGER.debug(
                "Need cv, change callback_meta, flow_id_list: {}".format(flow_id_list)
            )
            if len(flow_id_list) > 1:
                curve_name = ".".join(flow_id_list[1:])
                metric_meta.update_metas({"curve_name": curve_name})
        else:
            metric_meta.update_metas({"curve_name": metric_name})

        self.tracker.set_metric_meta(
            metric_name=metric_name,
            metric_namespace=metric_namespace,
            metric_meta=metric_meta,
        )

    def callback_metric(
        self, metric_name, metric_namespace, metric_data: typing.List[Metric]
    ):
        if self.need_cv:
            metric_name = ".".join([metric_name, str(self.cv_fold)])

        self.tracker.log_metric_data(
            metric_name=metric_name,
            metric_namespace=metric_namespace,
            metrics=metric_data,
        )

    def callback_warm_start_init_iter(self, iter_num):
        metric_meta = MetricMeta(
            name="train",
            metric_type="init_iter",
            extra_metas={
                "unit_name": "iters",
            },
        )
        self.callback_meta(
            metric_name="init_iter", metric_namespace="train", metric_meta=metric_meta
        )
        self.callback_metric(
            metric_name="init_iter",
            metric_namespace="train",
            metric_data=[Metric("init_iter", iter_num)],
        )

    def get_latest_checkpoint(self):
        return self.checkpoint_manager.latest_checkpoint.read()

    def save_summary(self):
        self.tracker.log_component_summary(summary_data=self.summary())

    def set_cv_fold(self, cv_fold):
        self.cv_fold = cv_fold

    def summary(self):
        return copy.deepcopy(self._summary)

    def set_summary(self, new_summary):
        """
        Model summary setter

        Parameters
        ----------
        new_summary: dict, summary to replace the original one
        """
        if not isinstance(new_summary, dict):
            raise ValueError(
                f"summary should be of dict type, received {type(new_summary)} instead."
            )
        self._summary = copy.deepcopy(new_summary)

    def add_summary(self, new_key, new_value):
        """
        Add a key:value pair to model summary

        Parameters
        ----------
        new_key: str
        new_value: object
        """
        original_value = self._summary.get(new_key, None)
        if original_value is not None:
            LOGGER.warning(
                f"{new_key} already exists in model summary. "
                f"Corresponding value {original_value} will be replaced by {new_value}."
            )
        self._summary[new_key] = new_value
        # LOGGER.debug(f"{new_key}: {new_value} added to summary.")

    def merge_summary(self, new_content, suffix=None, suffix_sep="_"):
        """
        Merge new content into model summary

        Parameters
        ----------
        new_content: dict, content to be merged into summary
        suffix: str or None, suffix used to create a new key if a key in new_content already exists in model summary
        suffix_sep: str, default '_', separator used to create the new key
        """
        if not isinstance(new_content, dict):
            raise ValueError(
                f"To merge new content into model summary, "
                f"value must be of dict type, received {type(new_content)} instead."
            )
        new_summary = self.summary()
        keyset = new_summary.keys() | new_content.keys()
        for key in keyset:
            if key in new_summary and key in new_content:
                if suffix is not None:
                    new_key = f"{key}{suffix_sep}{suffix}"
                else:
                    new_key = key
                new_value = new_content.get(key)
                new_summary[new_key] = new_value
            elif key in new_content:
                new_summary[key] = new_content.get(key)
        self.set_summary(new_summary)
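
    # Worked example for merge_summary: with self._summary == {"auc": 0.90},
    # calling merge_summary({"auc": 0.85, "ks": 0.42}, suffix="cv0") yields
    # {"auc": 0.90, "auc_cv0": 0.85, "ks": 0.42}; without a suffix, the
    # clashing key is simply overwritten.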

    @staticmethod
    def extract_data(data: dict):
        LOGGER.debug("In extract_data, data input: {}".format(data))
        if len(data) == 0:
            return data
        if len(data) == 1:
            return list(data.values())[0]
        return data

    @staticmethod
    def check_schema_content(schema):
        """
        Check schema for repeated header entries and illegal/non-printable
        characters (space excepted); non-ascii characters are allowed.

        :param schema: dict
        """
        abnormal_detection.check_legal_schema(schema)

    def align_data_header(self, data_instances, pre_header):
        """
        Align features of given data; raise an error if a value in the given
        schema is not found.

        :param data_instances: data table
        :param pre_header: list, header of model
        :return: dtable, aligned data
        """
        result_data = self._align_cache.get(id(data_instances))
        if result_data is None:
            result_data = header_alignment(
                data_instances=data_instances, pre_header=pre_header
            )
            self._align_cache[id(data_instances)] = result_data
        return result_data

    @staticmethod
    def pass_data(data):
        if isinstance(data, dict) and len(data) >= 1:
            data = list(data.values())[0]
        return data

    def obtain_data(self, data_list):
        if isinstance(data_list, list):
            return data_list[0]
        return data_list