#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools

import numpy as np

from fate_arch.computing import is_table
from federatedml.util import LOGGER

class RunningFuncs(object):
    """Ordered schedule of functions for a component run.

    Each entry records the callable, its positional parameters, whether its
    return value should be saved, and whether it should be fed the results
    saved by earlier entries (the ``use_previews`` flag).
    """

    def __init__(self):
        self.todo_func_list = []
        self.todo_func_params = []
        self.save_result = []
        self.use_previews_result = []

    def add_func(self, func, params, save_result=False, use_previews=False):
        self.todo_func_list.append(func)
        self.todo_func_params.append(params)
        self.save_result.append(save_result)
        self.use_previews_result.append(use_previews)

    def __iter__(self):
        for func, params, save_result, use_previews in zip(
            self.todo_func_list,
            self.todo_func_params,
            self.save_result,
            self.use_previews_result,
        ):
            yield func, params, save_result, use_previews
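
# Usage sketch (illustrative only; `model` and `train_data` stand for a
# component instance and an input table, neither defined in this module):
#
#     funcs = RunningFuncs()
#     funcs.add_func(model.set_flowid, ["fit"])
#     funcs.add_func(model.fit, [train_data])
#     funcs.add_func(model.predict, [train_data], save_result=True)
#     for func, params, save_result, use_previews in funcs:
#         ...  # the component runner invokes func(*params) here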

class DSLConfigError(ValueError):
    pass

class ComponentProperties(object):
    """Parses component parameters and DSL inputs, then derives the sequence
    of functions a component should run (fit, predict, cross-validation, ...).
    """

    def __init__(self):
        self.need_cv = False
        self.need_run = False
        self.need_stepwise = False
        self.has_model = False
        self.has_isometric_model = False
        self.has_train_data = False
        self.has_eval_data = False
        self.has_validate_data = False
        self.has_test_data = False
        self.has_normal_input_data = False
        self.role = None
        self.host_party_idlist = []
        self.local_partyid = -1
        self.guest_partyid = -1
        self.input_data_count = 0
        self.input_eval_data_count = 0
        self.caches = None
        self.is_warm_start = False
        self.has_arbiter = False

    def parse_caches(self, caches):
        self.caches = caches

    def parse_component_param(self, roles, param):
        # Components without cv_param / stepwise_param fall back to defaults.
        try:
            need_cv = param.cv_param.need_cv
        except AttributeError:
            need_cv = False
        self.need_cv = need_cv

        try:
            need_run = param.need_run
        except AttributeError:
            need_run = True
        self.need_run = need_run
        LOGGER.debug("need_run: {}, need_cv: {}".format(self.need_run, self.need_cv))

        try:
            need_stepwise = param.stepwise_param.need_stepwise
        except AttributeError:
            need_stepwise = False
        self.need_stepwise = need_stepwise

        self.has_arbiter = roles["role"].get("arbiter") is not None
        self.role = roles["local"]["role"]
        self.host_party_idlist = roles["role"].get("host")
        self.local_partyid = roles["local"].get("party_id")
        self.guest_partyid = roles["role"].get("guest")
        if self.guest_partyid is not None:
            self.guest_partyid = self.guest_partyid[0]
        return self
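
    # Illustrative shape of the `roles` argument (inferred from the lookups
    # above; actual values are produced by the FATE job scheduler):
    #
    #     roles = {
    #         "role": {"guest": [9999], "host": [10000], "arbiter": [10000]},
    #         "local": {"role": "guest", "party_id": 9999},
    #     }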

    def parse_dsl_args(self, datasets, model):
        if "model" in model and model["model"] is not None:
            self.has_model = True
        if "isometric_model" in model and model["isometric_model"] is not None:
            self.has_isometric_model = True
        LOGGER.debug(f"parse_dsl_args data_sets: {datasets}")
        if datasets is None:
            return self
        for data_key, data_dicts in datasets.items():
            data_keys = list(data_dicts.keys())
            for data_type in ["train_data", "eval_data", "validate_data", "test_data"]:
                if data_type in data_keys:
                    setattr(self, f"has_{data_type}", True)
                    data_keys.remove(data_type)
                    LOGGER.debug(
                        f"[Data Parser], has_{data_type}:"
                        f" {getattr(self, f'has_{data_type}')}"
                    )
            if len(data_keys) > 0:
                self.has_normal_input_data = True
        LOGGER.debug(
            "[Data Parser], has_normal_data: {}".format(self.has_normal_input_data)
        )
        if self.has_eval_data:
            if self.has_validate_data or self.has_test_data:
                raise DSLConfigError(
                    "eval_data input should not be configured simultaneously"
                    " with validate_data or test_data"
                )
        # self._abnormal_dsl_config_detect()
        if self.has_model and self.has_train_data:
            self.is_warm_start = True
        return self
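
    # Illustrative shape of the `datasets` argument (inferred from the parsing
    # logic above; the component name "reader_0" is hypothetical):
    #
    #     datasets = {
    #         "reader_0": {"train_data": <table>, "validate_data": <table>},
    #     }
    #
    # Any key other than the four recognised data types (e.g. a plain "data"
    # input) marks the component as having "normal" input data.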

    def _abnormal_dsl_config_detect(self):
        if self.has_validate_data:
            if not self.has_train_data:
                raise DSLConfigError(
                    "validate_data should be configured simultaneously"
                    " with train_data"
                )
        if self.has_train_data:
            if self.has_normal_input_data or self.has_test_data:
                raise DSLConfigError(
                    "train_data input should not be configured simultaneously"
                    " with data or test_data"
                )
        if self.has_normal_input_data:
            if self.has_train_data or self.has_validate_data or self.has_test_data:
                raise DSLConfigError(
                    "When data input has been configured, train_data,"
                    " validate_data or test_data should not be configured."
                )
        if self.has_test_data:
            if not self.has_model:
                raise DSLConfigError(
                    "When test_data input has been configured, model"
                    " input should be configured too."
                )
        if self.need_cv or self.need_stepwise:
            if not self.has_train_data:
                raise DSLConfigError(
                    "train_data should be configured in a cross-validate"
                    " or stepwise task"
                )
            if (
                self.has_validate_data
                or self.has_normal_input_data
                or self.has_test_data
            ):
                raise DSLConfigError(
                    "Only train_data should be configured in a cross-validate"
                    " or stepwise task"
                )
            if self.has_model or self.has_isometric_model:
                raise DSLConfigError(
                    "In a cross-validate or stepwise task, model"
                    " or isometric_model should not be configured"
                )
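
    # Summary of the input combinations the checks above allow (derived from
    # this method, not an exhaustive DSL specification):
    #
    #     train_data only                  -> fit (+ predict on train)
    #     train_data + validate/eval_data  -> fit + validate
    #     test_data + model                -> predict only
    #     normal "data" input              -> fit or transform
    #     cv / stepwise                    -> train_data only, no model inputs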

    def extract_input_data(self, datasets, model):
        model_data = {}
        data = {}
        LOGGER.debug(f"Input data_sets: {datasets}")
        for cpn_name, data_dict in datasets.items():
            for data_type in ["train_data", "eval_data", "validate_data", "test_data"]:
                if data_type in data_dict:
                    d_table = data_dict.get(data_type)
                    model_data[data_type] = model.obtain_data(d_table)
                    del data_dict[data_type]
            if len(data_dict) > 0:
                LOGGER.debug(f"data_dict: {data_dict}")
                for k, v in data_dict.items():
                    data_list = model.obtain_data(v)
                    LOGGER.debug(f"data_list: {data_list}")
                    if isinstance(data_list, list):
                        for i, data_i in enumerate(data_list):
                            data[".".join([cpn_name, k, str(i)])] = data_i
                    else:
                        data[".".join([cpn_name, k])] = data_list

        train_data = model_data.get("train_data")
        validate_data = None
        if self.has_train_data:
            if self.has_eval_data:
                validate_data = model_data.get("eval_data")
            elif self.has_validate_data:
                validate_data = model_data.get("validate_data")

        test_data = None
        if self.has_test_data:
            test_data = model_data.get("test_data")
        elif self.has_eval_data and not self.has_train_data:
            # eval_data without train_data is treated as test data
            test_data = model_data.get("eval_data")
            self.has_test_data = True

        if validate_data or (self.has_train_data and self.has_eval_data):
            self.has_validate_data = True

        if self.has_train_data and is_table(train_data):
            self.input_data_count = train_data.count()
        elif self.has_normal_input_data:
            # If several normal inputs are tables, the last one's count wins.
            for data_key, data_table in data.items():
                if is_table(data_table):
                    self.input_data_count = data_table.count()

        if self.has_validate_data and is_table(validate_data):
            self.input_eval_data_count = validate_data.count()

        self._abnormal_dsl_config_detect()
        LOGGER.debug(
            f"train_data: {train_data}, validate_data: {validate_data}, "
            f"test_data: {test_data}, data: {data}"
        )
        return train_data, validate_data, test_data, data

    def warm_start_process(self, running_funcs, model, train_data, validate_data, schema=None):
        # Reuse the schema of whichever input is available.
        if schema is None:
            for d in [train_data, validate_data]:
                if d is not None:
                    schema = d.schema
                    break
        running_funcs = self._train_process(running_funcs, model, train_data, validate_data,
                                            test_data=None, schema=schema)
        return running_funcs

    def _train_process(self, running_funcs, model, train_data, validate_data, test_data, schema):
        if self.has_train_data and self.has_validate_data:
            # Fit on train data, predict on both train and validate sets,
            # then union the two prediction tables.
            running_funcs.add_func(model.set_flowid, ['fit'])
            running_funcs.add_func(model.fit, [train_data, validate_data])
            running_funcs.add_func(model.set_flowid, ['validate'])
            running_funcs.add_func(model.predict, [train_data], save_result=True)
            running_funcs.add_func(model.set_flowid, ['predict'])
            running_funcs.add_func(model.predict, [validate_data], save_result=True)
            running_funcs.add_func(self.union_data, ["train", "validate"], use_previews=True, save_result=True)
            running_funcs.add_func(model.set_predict_data_schema, [schema],
                                   use_previews=True, save_result=True)
        elif self.has_train_data:
            running_funcs.add_func(model.set_flowid, ['fit'])
            running_funcs.add_func(model.fit, [train_data])
            running_funcs.add_func(model.set_flowid, ['validate'])
            running_funcs.add_func(model.predict, [train_data], save_result=True)
            running_funcs.add_func(self.union_data, ["train"], use_previews=True, save_result=True)
            running_funcs.add_func(model.set_predict_data_schema, [schema],
                                   use_previews=True, save_result=True)
        elif self.has_test_data:
            running_funcs.add_func(model.set_flowid, ['predict'])
            running_funcs.add_func(model.predict, [test_data], save_result=True)
            running_funcs.add_func(self.union_data, ["predict"], use_previews=True, save_result=True)
            running_funcs.add_func(model.set_predict_data_schema, [schema],
                                   use_previews=True, save_result=True)
        return running_funcs
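
    # For the train + validate branch above, the resulting schedule is, in
    # order (a sketch; execution is driven by the component runner):
    #
    #     set_flowid("fit") -> fit(train, validate) -> set_flowid("validate")
    #     -> predict(train)* -> set_flowid("predict") -> predict(validate)*
    #     -> union_data(*, ["train", "validate"]) -> set_predict_data_schema(*)
    #
    # Entries marked * have save_result=True; union_data and
    # set_predict_data_schema consume those saved results (use_previews=True).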

    def extract_running_rules(self, datasets, models, cpn):
        """Build the RunningFuncs schedule for component `cpn` from its inputs."""
        train_data, validate_data, test_data, data = self.extract_input_data(
            datasets, cpn
        )

        running_funcs = RunningFuncs()

        schema = None
        for d in [train_data, validate_data, test_data]:
            if d is not None:
                schema = d.schema
                break

        if not self.need_run:
            running_funcs.add_func(cpn.pass_data, [data], save_result=True)
            return running_funcs

        if self.need_cv:
            running_funcs.add_func(cpn.cross_validation, [train_data], save_result=True)
            return running_funcs

        if self.need_stepwise:
            running_funcs.add_func(cpn.stepwise, [train_data], save_result=True)
            running_funcs.add_func(self.union_data, ["train"], use_previews=True, save_result=True)
            running_funcs.add_func(cpn.set_predict_data_schema, [schema],
                                   use_previews=True, save_result=True)
            return running_funcs

        if self.has_model or self.has_isometric_model:
            running_funcs.add_func(cpn.load_model, [models])

        if self.is_warm_start:
            return self.warm_start_process(running_funcs, cpn, train_data, validate_data, schema)

        running_funcs = self._train_process(running_funcs, cpn, train_data, validate_data, test_data, schema)

        if self.has_normal_input_data and not self.has_model:
            running_funcs.add_func(cpn.extract_data, [data], save_result=True)
            running_funcs.add_func(cpn.set_flowid, ['fit'])
            running_funcs.add_func(cpn.fit, [], use_previews=True, save_result=True)

        if self.has_normal_input_data and self.has_model:
            running_funcs.add_func(cpn.extract_data, [data], save_result=True)
            running_funcs.add_func(cpn.set_flowid, ['transform'])
            running_funcs.add_func(cpn.transform, [], use_previews=True, save_result=True)

        return running_funcs
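
    # Consumption sketch (hypothetical driver loop; in FATE the component
    # runner plays this role and its parameter-passing is more involved):
    #
    #     saved_result = []
    #     for func, params, save_result, use_previews in running_funcs:
    #         # use_previews entries receive earlier saved results, e.g.
    #         # union_data(saved_tables, name_list)
    #         result = func(saved_result, params) if use_previews else func(*params)
    #         if save_result:
    #             saved_result.append(result)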

    @staticmethod
    def union_data(previews_data, name_list):
        if len(previews_data) == 0:
            return None
        if any([x is None for x in previews_data]):
            return None
        assert len(previews_data) == len(name_list)

        def _append_name(value, name):
            inst = copy.deepcopy(value)
            if isinstance(inst.features, list):
                inst.features.append(name)
            else:
                inst.features = np.append(inst.features, name)
            return inst

        result_data = None
        for data, name in zip(previews_data, name_list):
            # Tag each instance with its data-set name ("train", "validate", ...)
            f = functools.partial(_append_name, name=name)
            data = data.mapValues(f)
            if result_data is None:
                result_data = data
            else:
                LOGGER.debug(
                    f"Before union, t1 count: {result_data.count()}, t2 count: {data.count()}"
                )
                result_data = result_data.union(data)
                LOGGER.debug(f"After union, result count: {result_data.count()}")
        return result_data
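
    # Example effect of union_data (sketch; `train_pred` and `val_pred` stand
    # for saved prediction tables, assumed to have disjoint key spaces):
    #
    #     merged = ComponentProperties.union_data([train_pred, val_pred],
    #                                             ["train", "validate"])
    #
    # Every instance in `merged` carries an extra trailing feature naming the
    # data set it came from, so downstream consumers can split the union again.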

    def set_union_func(self, func):
        self.union_data = func
|