123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import copy
- from fate_arch.abc import Components
- from fate_flow.component_env_utils import provider_utils
- from fate_flow.entity import ComponentProvider
- from fate_flow.db.component_registry import ComponentRegistry
class RuntimeConfParserUtil(object):
    """Stateless helpers for parsing job runtime confs (DSL v1 and v2).

    All helpers read plain dicts (dsl, runtime/submit conf, provider detail)
    and return new dicts; ``merge_predict_runtime_conf`` and
    ``get_component_parameters`` never modify their inputs.
    """

    @classmethod
    def get_input_parameters(cls, submit_dict, components=None):
        """Delegate to ``RuntimeConfParserV2.get_input_parameters``."""
        return RuntimeConfParserV2.get_input_parameters(submit_dict, components=components)

    @classmethod
    def get_job_parameters(cls, submit_dict, conf_version=1):
        """Return ``{role: {party_id: job_parameters}}``.

        ``conf_version == 1`` uses the v1 parser; any other value uses v2.
        """
        if conf_version == 1:
            return RuntimeConfParserV1.get_job_parameters(submit_dict)
        return RuntimeConfParserV2.get_job_parameters(submit_dict)

    @staticmethod
    def merge_dict(dict1, dict2):
        """Recursively merge ``dict2`` over ``dict1`` into a new dict.

        For keys present in both, values are merged recursively when both are
        dicts; otherwise the value from ``dict2`` wins. The inputs are not
        modified, but leaf values are not copied (they are shared).
        """
        merge_ret = {}
        for key in dict1.keys() | dict2.keys():
            if key in dict1 and key in dict2:
                val1, val2 = dict1[key], dict2[key]
                # Recurse only when both sides are mappings; a non-dict override
                # simply replaces the old value instead of crashing.
                if isinstance(val1, dict) and isinstance(val2, dict):
                    merge_ret[key] = RuntimeConfParserUtil.merge_dict(val1, val2)
                else:
                    merge_ret[key] = val2
            elif key in dict1:
                merge_ret[key] = dict1[key]
            else:
                merge_ret[key] = dict2[key]
        return merge_ret

    @staticmethod
    def generate_predict_conf_template(predict_dsl, train_conf, model_id, model_version):
        """Delegate to ``RuntimeConfParserV2.generate_predict_conf_template``."""
        return RuntimeConfParserV2.generate_predict_conf_template(predict_dsl, train_conf, model_id, model_version)

    @staticmethod
    def get_module_name(module, role, provider: "Components"):
        """Return the run-object name of ``module`` for ``role`` from ``provider``."""
        provider_components = ComponentRegistry.get_provider_components(
            provider.provider_name, provider.provider_version
        )
        return provider.get(module, provider_components).get_run_obj_name(role)

    @staticmethod
    def get_component_parameters(
        provider,
        runtime_conf,
        module,
        alias,
        redundant_param_check,
        local_role,
        local_party_id,
        parse_user_specified_only,
        pre_parameters=None
    ):
        """Build the runtime conf of component ``alias`` for the local party.

        Returns ``{}`` when ``local_role`` is not among the roles supported by
        ``module``. Otherwise returns a shallow copy of ``runtime_conf`` without
        the ``algorithm_parameters``/``role_parameters``/``component_parameters``
        sections, with ``role`` restricted to supported roles and
        ``local``/``module``/``CodePath``/``ComponentParam`` filled in.

        Parameters are applied in this order (later wins):
        ``pre_parameters["ComponentParam"]``, then
        ``component_parameters.common[alias]``, then every matching
        ``component_parameters.role[local_role][<"all" | "i|j|...">][alias]``.
        With ``parse_user_specified_only`` only these raw user values are
        returned; otherwise they are applied to the module's param object,
        which is serialised with ``as_dict()`` and then validated via ``check()``.
        """
        provider_components = ComponentRegistry.get_provider_components(
            provider.provider_name, provider.provider_version
        )
        component = provider.get(module, provider_components)

        support_roles = component.get_supported_roles()
        if runtime_conf["role"] is not None:
            support_roles = [r for r in runtime_conf["role"] if r in support_roles]
        role_on_module = copy.deepcopy(runtime_conf["role"])
        for role in runtime_conf["role"]:
            if role not in support_roles:
                del role_on_module[role]
        if local_role not in role_on_module:
            return {}

        conf = {
            key: value
            for key, value in runtime_conf.items()
            if key not in ("algorithm_parameters", "role_parameters", "component_parameters")
        }
        conf["role"] = role_on_module
        # Copy "local" so that filling in role/party_id does not mutate the
        # caller's runtime_conf.
        conf["local"] = dict(runtime_conf.get("local") or {})
        conf["local"].update({"role": local_role, "party_id": local_party_id})
        conf["module"] = module
        conf["CodePath"] = component.get_run_obj_name(local_role)

        param_class = component.get_param_obj(alias)
        role_idx = role_on_module[local_role].index(local_party_id)
        user_specified_parameters = dict()

        if pre_parameters:
            pre_component_param = pre_parameters.get("ComponentParam", {})
            if parse_user_specified_only:
                user_specified_parameters.update(pre_component_param)
            else:
                param_class = param_class.update(pre_component_param)

        component_parameters = runtime_conf.get("component_parameters", {})
        common_parameters = component_parameters.get("common", {}).get(alias, {})
        if parse_user_specified_only:
            user_specified_parameters.update(common_parameters)
        else:
            param_class = param_class.update(common_parameters, not redundant_param_check)

        # Role parameters are keyed by "all" or by "|"-joined party indexes.
        role_parameters = component_parameters.get("role", {}).get(local_role, {})
        for role_id, role_id_parameters in role_parameters.items():
            if role_id == "all" or str(role_idx) in role_id.split("|"):
                parameters = role_id_parameters.get(alias, {})
                if parse_user_specified_only:
                    user_specified_parameters.update(parameters)
                else:
                    param_class.update(parameters, not redundant_param_check)

        if parse_user_specified_only:
            conf["ComponentParam"] = user_specified_parameters
        else:
            conf["ComponentParam"] = param_class.as_dict()
            param_class.check()
        return conf

    @staticmethod
    def convert_parameters_v1_to_v2(party_idx, parameter_v1, not_builtin_vars):
        """Pick party ``party_idx``'s values out of a v1 parameter dict.

        Keys not in ``not_builtin_vars`` are expected to map to per-party
        sequences and are indexed by ``party_idx``; keys in
        ``not_builtin_vars`` hold nested parameter dicts and are converted
        recursively.
        """
        parameter_v2 = {}
        for key, values in parameter_v1.items():
            if key not in not_builtin_vars:
                parameter_v2[key] = values[party_idx]
            else:
                parameter_v2[key] = RuntimeConfParserUtil.convert_parameters_v1_to_v2(
                    party_idx, values, not_builtin_vars
                )
        return parameter_v2

    @staticmethod
    def get_v1_role_parameters(provider, component, runtime_conf, dsl):
        """Convert v1 ``role_parameters`` of one component to v2 role form.

        Returns ``{role: {str(party_idx): {component: params}}}``. For a
        ``Reader`` module the tables come from
        ``role_parameters[role]["args"]["data"][data_key]``, where ``data_key``
        is the component's first data output name in ``dsl``; other modules
        split their v1 per-party lists with ``convert_parameters_v1_to_v2``.
        """
        component_role_parameters = dict()
        if "role_parameters" not in runtime_conf:
            return component_role_parameters
        role_parameters = runtime_conf["role_parameters"]
        module = dsl["components"][component]["module"]
        if module == "Reader":
            data_key = dsl["components"][component]["output"]["data"][0]
            for role, role_params in role_parameters.items():
                if not role_params.get("args", {}).get("data", {}).get(data_key):
                    continue
                dataset = role_params["args"]["data"][data_key]
                component_role_parameters[role] = {
                    str(idx): {component: {"table": table}} for idx, table in enumerate(dataset)
                }
        else:
            provider_components = ComponentRegistry.get_provider_components(
                provider.provider_name, provider.provider_version
            )
            param_class = provider.get(module, provider_components).get_param_obj(component)
            extract_not_builtin = getattr(param_class, "extract_not_builtin", None)
            not_builtin_vars = extract_not_builtin() if extract_not_builtin is not None else {}
            for role, role_params in role_parameters.items():
                params = role_params.get(component, {})
                if not params:
                    continue
                component_role_parameters[role] = dict()
                party_num = len(runtime_conf["role"][role])
                for party_idx in range(party_num):
                    party_param = RuntimeConfParserUtil.convert_parameters_v1_to_v2(
                        party_idx, params, not_builtin_vars
                    )
                    component_role_parameters[role][str(party_idx)] = {component: party_param}
        return component_role_parameters

    @staticmethod
    def get_job_providers_by_dsl(dsl, provider_detail):
        """Resolve ``{component: {"module", "provider": {"name", "version"}}}`` from a dsl.

        A component-level ``provider`` entry overrides the dsl-level
        ``provider`` entry, which in turn overrides the registry defaults in
        ``provider_detail``. Accepted formats are ``name`` and
        ``name@version``.

        :raises ValueError: on a malformed dsl-level provider string
            (``@version`` or more than one ``@``).
        """
        provider_info = {}
        global_provider_name = None
        global_provider_version = None
        if "provider" in dsl:
            global_provider_msg = dsl["provider"].split("@")
            # Reject a missing name ("@version") and extra "@" separators.
            if dsl["provider"].startswith("@") or len(global_provider_msg) > 2:
                raise ValueError("Provider format should be provider_name@provider_version or provider_name, "
                                 "@provider_version is not supported")
            if len(global_provider_msg) == 1:
                global_provider_name = global_provider_msg[0]
            else:
                global_provider_name, global_provider_version = global_provider_msg

        for component, component_conf in dsl["components"].items():
            module = component_conf["module"]
            provider_config = component_conf.get("provider")
            name, version = RuntimeConfParserUtil.get_component_provider_by_user_conf(component,
                                                                                      module,
                                                                                      provider_config,
                                                                                      provider_detail,
                                                                                      global_provider_name,
                                                                                      global_provider_version)
            provider_info[component] = {
                "module": module,
                "provider": {
                    "name": name,
                    "version": version
                }
            }
        return provider_info

    @classmethod
    def get_job_providers(cls, dsl, provider_detail, submit_dict=None, local_role=None, local_party_id=None):
        """Resolve component providers, optionally for every party of a job.

        Without ``submit_dict`` this returns the dsl-level result of
        ``get_job_providers_by_dsl``. With it, returns
        ``{role: {party_id: provider_info}}``:

        * dsl v1 or no ``provider`` section: only the local party gets the
          dsl-level info; all other parties get ``{}``.
        * otherwise ``provider.common`` applies to every party and
          ``provider.role[role]["all" | "i|j"]`` overrides per party. Only the
          local party's providers are validated against ``provider_detail``.

        :raises ValueError: if ``local_role``/``local_party_id`` are missing or
            not part of ``submit_dict["role"]``, or a provider entry names a
            component that is not in the dsl.
        """
        provider_info = cls.get_job_providers_by_dsl(dsl, provider_detail)
        if submit_dict is None:
            return provider_info

        if local_party_id is None or local_role is None \
                or local_role not in submit_dict["role"] \
                or (str(local_party_id) not in submit_dict["role"][local_role]
                    and int(local_party_id) not in submit_dict["role"][local_role]):
            raise ValueError("when parse provider from conf, local role & party_id should not be None "
                             "and should be in conf's role")

        def _is_local_party(role, party_id):
            # party ids may be str or int depending on where the conf came from
            return role == local_role and str(party_id) == str(local_party_id)

        provider_info_all_party = {}
        dsl_version = submit_dict.get("dsl_version", 1)
        if dsl_version == 1 or "provider" not in submit_dict:
            for role, party_id_list in submit_dict["role"].items():
                provider_info_all_party[role] = {party_id: dict() for party_id in party_id_list}
            provider_info_all_party[local_role][local_party_id] = provider_info
            return provider_info_all_party

        provider_config = submit_dict["provider"]
        common_provider_config = provider_config.get("common") or {}
        other_party_provider_config = dict()
        for component, provider_msg in common_provider_config.items():
            if component not in provider_info:
                raise ValueError(f"Redundant component {component} is not found in dsl")
            module = provider_info[component]["module"]
            name, version = cls.get_component_provider_by_user_conf(component,
                                                                    module,
                                                                    provider_msg,
                                                                    provider_detail)
            provider_info[component]["provider"] = dict(name=name, version=version)
            # Other parties' registries are unknown here, so no validation.
            other_name, other_version = cls.get_component_provider_by_user_conf(component,
                                                                                module,
                                                                                provider_msg)
            other_party_provider_config[component] = {
                "module": module,
                "provider": {
                    "name": other_name,
                    "version": other_version
                }
            }

        provider_info_all_party[local_role] = {local_party_id: copy.deepcopy(provider_info)}
        role_config_all = provider_config.get("role", {})
        for role, party_id_list in submit_dict["role"].items():
            provider_info_all_party.setdefault(role, {})
            role_provider_config = role_config_all.get(role, {})
            for idx, party_id in enumerate(party_id_list):
                is_local = _is_local_party(role, party_id)
                if is_local:
                    provider_info_party = copy.deepcopy(provider_info)
                else:
                    provider_info_party = copy.deepcopy(other_party_provider_config)
                for role_id, role_id_provider_config in role_provider_config.items():
                    if role_id != "all" and str(idx) not in role_id.split("|"):
                        continue
                    for component, provider_msg in role_id_provider_config.items():
                        if component not in dsl["components"]:
                            raise ValueError(f"Redundant component {component} is not found in dsl")
                        module = dsl["components"][component]["module"]
                        # Only the local party can be validated against the
                        # local provider registry.
                        detail_info = provider_detail if is_local else None
                        name, version = cls.get_component_provider_by_user_conf(component,
                                                                                module,
                                                                                provider_msg,
                                                                                provider_detail=detail_info)
                        provider_info_party.setdefault(component, dict(module=module))
                        provider_info_party[component]["provider"] = dict(name=name, version=version)
                provider_info_all_party[role][party_id] = provider_info_party
        return provider_info_all_party

    @staticmethod
    def get_component_provider_by_user_conf(component, module, provider_config, provider_detail=None,
                                            default_name=None, default_version=None):
        """Parse a ``name`` / ``name@version`` provider string into ``(name, version)``.

        Falls back to ``default_name``/``default_version`` when no name is
        given. Without ``provider_detail`` the parsed values are returned
        as-is. With it, they are validated against the registry, and a missing
        version (or missing name and version) is filled from the registry
        defaults.

        :raises ValueError: on a malformed string or an unregistered
            provider/version.
        """
        name, version = None, None
        if provider_config:
            provider_msg = provider_config.split("@")
            if provider_config.startswith("@") or len(provider_msg) > 2:
                raise ValueError("Provider format should be provider_name@provider_version or provider_name, "
                                 "@provider_version is not supported")
            if len(provider_msg) == 2:
                name, version = provider_msg
            else:
                name = provider_msg[0]

        if not name and default_name:
            name = default_name
            version = default_version

        if provider_detail is None:
            return name, version

        if name and name not in provider_detail["components"][module]["support_provider"]:
            raise ValueError(f"Provider: {name} does not support in {module}, please register")
        if version and version not in provider_detail["providers"][name]:
            raise ValueError(f"Provider: {name} version: {version} does not support in {module}, please register")

        if name and not version:
            version = RuntimeConfParserUtil.get_component_provider(alias=component,
                                                                   module=module,
                                                                   provider_detail=provider_detail,
                                                                   name=name)
        elif not name and not version:
            name, version = RuntimeConfParserUtil.get_component_provider(alias=component,
                                                                         module=module,
                                                                         provider_detail=provider_detail)
        return name, version

    @staticmethod
    def get_component_provider(alias, module, provider_detail, detect=True, name=None):
        """Look up the default provider of ``module`` in ``provider_detail``.

        Returns ``(name, version)`` when ``name`` is None. When ``name`` is
        given, returns only that provider's default version. If the module is
        unknown, raises ``ValueError`` when ``detect`` is true and returns
        ``None`` otherwise.
        """
        if module not in provider_detail["components"]:
            if detect:
                raise ValueError(f"component {alias}, module {module}'s provider does not exist")
            return None

        if name is None:
            name = provider_detail["components"][module]["default_provider"]
            version = provider_detail["providers"][name]["default"]["version"]
            return name, version

        if name not in provider_detail["components"][module]["support_provider"]:
            raise ValueError(f"Provider {name} does not support, please register in fate-flow")
        return provider_detail["providers"][name]["default"]["version"]

    @staticmethod
    def instantiate_component_provider(provider_detail, alias=None, module=None, provider_name=None,
                                       provider_version=None, local_role=None, local_party_id=None,
                                       detect=True, provider_cache=None, job_parameters=None):
        """Instantiate a provider interface, storing it in ``provider_cache`` if given.

        With both ``provider_name`` and ``provider_version``, the provider at
        ``provider_detail["providers"][name][version]["path"]`` is loaded.
        Otherwise the default provider of ``module`` is resolved first.
        Returns ``None`` when ``detect`` is false and ``module`` has no
        registered provider. ``local_role``, ``local_party_id`` and
        ``job_parameters`` are accepted but not used here.
        """
        if provider_name and provider_version:
            provider_path = provider_detail["providers"][provider_name][provider_version]["path"]
            provider = provider_utils.get_provider_interface(ComponentProvider(name=provider_name,
                                                                               version=provider_version,
                                                                               path=provider_path,
                                                                               class_path=ComponentRegistry.get_default_class_path()))
            if provider_cache is not None:
                provider_cache.setdefault(provider_name, {})[provider_version] = provider
            return provider

        # NOTE(review): a provider_name given without a version is ignored
        # here and the module's default provider is used instead — confirm
        # this is intended.
        resolved = RuntimeConfParserUtil.get_component_provider(alias=alias,
                                                                module=module,
                                                                provider_detail=provider_detail,
                                                                detect=detect)
        if resolved is None:
            # detect=False and the module has no registered provider.
            return None
        provider_name, provider_version = resolved
        return RuntimeConfParserUtil.instantiate_component_provider(provider_detail,
                                                                    provider_name=provider_name,
                                                                    provider_version=provider_version,
                                                                    provider_cache=provider_cache)

    @classmethod
    def merge_predict_runtime_conf(cls, train_conf, predict_conf):
        """Merge a predict conf over a training conf and return a new conf.

        Roles absent from the predict conf are dropped, together with their
        ``job_parameters``/``component_parameters`` role sections. When a
        role's predict party list differs from the training one, role-section
        keys (``"i|j"`` party indexes) are re-indexed to the predict party
        positions; ``"all"`` keys are kept as-is. Finally ``predict_conf`` is
        merged on top with ``merge_dict``. Neither input is modified.

        :raises ValueError: if predict roles are not a subset of training
            roles, or a predict party did not take part in training.
        """
        runtime_conf = copy.deepcopy(train_conf)
        train_role = train_conf.get("role")
        predict_role = predict_conf.get("role")
        if not set(predict_role).issubset(train_role):
            raise ValueError(f"Predict roles is {predict_role}, train roles is {train_role}, "
                             "predict roles should be subset of train role")

        for role, train_party_ids in train_role.items():
            if role not in predict_role:
                del runtime_conf["role"][role]
                for p_type in ("job_parameters", "component_parameters"):
                    if runtime_conf.get(p_type, {}).get("role", {}).get(role):
                        del runtime_conf[p_type]["role"][role]
                continue

            predict_party_ids = predict_role[role]
            for party_id in predict_party_ids:
                if party_id not in train_party_ids:
                    raise ValueError(f"Predict role: {role} party_id: {party_id} not occurs in training")
            # Same parties in the same order: indexes stay valid.
            if list(predict_party_ids) == list(train_party_ids):
                continue

            for p_type in ("job_parameters", "component_parameters"):
                conf = runtime_conf.get(p_type, {}).get("role", {}).get(role)
                if not conf:
                    continue
                new_conf = {}
                for party_key, party_conf in conf.items():
                    if party_key == "all":
                        new_party_key = "all"
                    else:
                        new_party_list = []
                        for party in party_key.split("|"):
                            party_id = train_party_ids[int(party)]
                            if party_id in predict_party_ids:
                                new_party_list.append(str(predict_party_ids.index(party_id)))
                        if not new_party_list:
                            # none of these parties take part in prediction
                            continue
                        new_party_key = "|".join(new_party_list)
                    new_conf.setdefault(new_party_key, {}).update(party_conf)
                runtime_conf[p_type]["role"][role] = new_conf

        return cls.merge_dict(runtime_conf, predict_conf)

    @staticmethod
    def get_model_loader_alias(component_name, runtime_conf, local_role, local_party_id):
        """Return ``component_name``'s ``component_name`` parameter for the local party.

        A matching role entry (``"all"`` or a ``"|"``-joined index list
        containing the party's index) with a truthy value wins; otherwise the
        value from ``component_parameters.common`` is returned (possibly
        ``None``).
        """
        component_parameters = runtime_conf.get("component_parameters", {})
        common_alias = component_parameters.get("common", {}).get(component_name, {}).get("component_name")
        role_params = component_parameters.get("role", {}).get(local_role)
        if not role_params:
            return common_alias

        party_idx = runtime_conf.get("role").get(local_role).index(local_party_id)
        for id_list, params in role_params.items():
            if id_list == "all" or str(party_idx) in id_list.split("|"):
                model_load_alias = params.get(component_name, {}).get("component_name")
                if model_load_alias:
                    return model_load_alias
        return common_alias
class RuntimeConfParserV1(object):
    """Job-parameter parser for v1 runtime confs."""

    @staticmethod
    def get_job_parameters(submit_dict):
        """Give each party of each role an independent copy of ``job_parameters``.

        :param submit_dict: runtime conf with a ``role`` mapping
            (role -> list of party ids) and optional ``job_parameters``.
        :return: ``{role: {party_id: deep copy of job_parameters}}``.
        """
        shared_parameters = submit_dict.get("job_parameters", {})
        return {
            role_name: {pid: copy.deepcopy(shared_parameters) for pid in parties}
            for role_name, parties in submit_dict["role"].items()
        }
class RuntimeConfParserV2(object):
    """Parsers for v2 runtime confs (``common`` / ``role`` sections)."""

    @classmethod
    def get_input_parameters(cls, submit_dict, components=None):
        """Collect per-party parameters of the given (reader) components.

        Returns ``{"dsl_version": 2, role: [party0_dict, party1_dict, ...]}``,
        one entry per role listed under ``component_parameters.role``, where
        each party dict maps every name in ``components`` to its parameters
        (``{}`` if unset). Role keys may be ``"all"`` or ``"|"``-joined party
        indexes. Returns ``{}`` when there are no role parameters or
        ``components`` is None.
        """
        all_role_parameters = submit_dict.get("component_parameters", {}).get("role")
        if all_role_parameters is None or components is None:
            return {}
        if not all_role_parameters:
            return {}

        input_parameters = {"dsl_version": 2}
        for role, role_parameters in all_role_parameters.items():
            party_num = len(submit_dict["role"][role])
            # Build one independent dict per party; ``[d] * n`` would make
            # every party share (and overwrite) the same dict.
            input_parameters[role] = [{cpn: {} for cpn in components} for _ in range(party_num)]
            for idx, parameters in role_parameters.items():
                readers = [reader for reader in components if reader in parameters]
                if not readers:
                    continue
                if idx == "all":
                    party_indexes = range(party_num)
                else:
                    party_indexes = [int(i) for i in idx.split("|")]
                for reader in readers:
                    for party_idx in party_indexes:
                        input_parameters[role][party_idx][reader] = parameters[reader]
        return input_parameters

    @staticmethod
    def get_job_parameters(submit_dict):
        """Return ``{role: {party_id: job_parameters}}`` for a v2 conf.

        Each party starts from a deep copy of ``job_parameters.common`` and
        then has every matching ``job_parameters.role[role]["all" | "i|j"]``
        entry merged on top with ``RuntimeConfParserUtil.merge_dict``.
        Returned dicts do not share objects with ``submit_dict``.
        """
        ret = {}
        job_parameters = submit_dict.get("job_parameters", {})
        common_job_parameters = job_parameters.get("common", {})
        role_job_parameters = job_parameters.get("role", {})
        for role, party_id_list in submit_dict["role"].items():
            if not role_job_parameters:
                ret[role] = {party_id: copy.deepcopy(common_job_parameters) for party_id in party_id_list}
                continue
            this_role_parameters = role_job_parameters.get(role, {})
            ret[role] = {}
            for idx, party_id in enumerate(party_id_list):
                parameters = copy.deepcopy(common_job_parameters)
                for role_id, role_id_parameters in this_role_parameters.items():
                    if role_id == "all" or str(idx) in role_id.split("|"):
                        # deep-copy so parties never share nested objects
                        parameters = RuntimeConfParserUtil.merge_dict(parameters,
                                                                      copy.deepcopy(role_id_parameters))
                ret[role][party_id] = parameters
        return ret

    @staticmethod
    def generate_predict_conf_template(predict_dsl, train_conf, model_id, model_version):
        """Build a v2 predict conf skeleton from a training conf.

        Copies ``role``, ``initiator`` and ``job_parameters`` from
        ``train_conf`` (without modifying it), sets ``model_id``,
        ``model_version`` and ``job_type="predict"`` in
        ``job_parameters.common``, and fills, for every guest/host party, a
        placeholder table for each ``Reader`` component of ``predict_dsl``.

        :raises ValueError: if ``train_conf`` lacks ``role`` or ``initiator``.
        """
        if not train_conf.get("role") or not train_conf.get("initiator"):
            raise ValueError("role and initiator should be contain in job's trainconf")

        predict_conf = dict()
        predict_conf["dsl_version"] = 2
        predict_conf["role"] = copy.deepcopy(train_conf.get("role"))
        predict_conf["initiator"] = copy.deepcopy(train_conf.get("initiator"))
        # Deep-copy so that the update below never mutates train_conf.
        job_parameters = copy.deepcopy(train_conf.get("job_parameters") or {})
        job_parameters.setdefault("common", {}).update({"model_id": model_id,
                                                        "model_version": model_version,
                                                        "job_type": "predict"})
        predict_conf["job_parameters"] = job_parameters

        reader_components = [
            module_alias
            for module_alias, module_info in predict_dsl.get("components", {}).items()
            if module_info.get("module") == "Reader"
        ]
        fill_template = {
            reader_alias: {"table": {"name": "name_to_be_filled_" + str(idx),
                                     "namespace": "namespace_to_be_filled_" + str(idx)}}
            for idx, reader_alias in enumerate(reader_components)
        }

        predict_conf["component_parameters"] = {"role": {}}
        for role, party_ids in predict_conf["role"].items():
            if role not in ["guest", "host"]:
                continue
            # Each party gets its own copy so filling one does not fill all.
            predict_conf["component_parameters"]["role"][role] = {
                str(idx): copy.deepcopy(fill_template) for idx in range(len(party_ids))
            }
        return predict_conf
|