runtime_conf_parse_util.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. from fate_arch.abc import Components
  18. from fate_flow.component_env_utils import provider_utils
  19. from fate_flow.entity import ComponentProvider
  20. from fate_flow.db.component_registry import ComponentRegistry
  21. class RuntimeConfParserUtil(object):
  22. @classmethod
  23. def get_input_parameters(cls, submit_dict, components=None):
  24. return RuntimeConfParserV2.get_input_parameters(submit_dict, components=components)
  25. @classmethod
  26. def get_job_parameters(cls, submit_dict, conf_version=1):
  27. if conf_version == 1:
  28. return RuntimeConfParserV1.get_job_parameters(submit_dict)
  29. else:
  30. return RuntimeConfParserV2.get_job_parameters(submit_dict)
  31. @staticmethod
  32. def merge_dict(dict1, dict2):
  33. merge_ret = {}
  34. key_set = dict1.keys() | dict2.keys()
  35. for key in key_set:
  36. if key in dict1 and key in dict2:
  37. val1 = dict1.get(key)
  38. val2 = dict2.get(key)
  39. if isinstance(val1, dict):
  40. merge_ret[key] = RuntimeConfParserUtil.merge_dict(val1, val2)
  41. else:
  42. merge_ret[key] = val2
  43. elif key in dict1:
  44. merge_ret[key] = dict1.get(key)
  45. else:
  46. merge_ret[key] = dict2.get(key)
  47. return merge_ret
  48. @staticmethod
  49. def generate_predict_conf_template(predict_dsl, train_conf, model_id, model_version):
  50. return RuntimeConfParserV2.generate_predict_conf_template(predict_dsl, train_conf, model_id, model_version)
  51. @staticmethod
  52. def get_module_name(module, role, provider: Components):
  53. return provider.get(module, ComponentRegistry.get_provider_components(provider.provider_name, provider.provider_version)).get_run_obj_name(role)
  54. @staticmethod
  55. def get_component_parameters(
  56. provider,
  57. runtime_conf,
  58. module,
  59. alias,
  60. redundant_param_check,
  61. local_role,
  62. local_party_id,
  63. parse_user_specified_only,
  64. pre_parameters=None
  65. ):
  66. provider_components = ComponentRegistry.get_provider_components(
  67. provider.provider_name, provider.provider_version
  68. )
  69. support_roles = provider.get(module, provider_components).get_supported_roles()
  70. if runtime_conf["role"] is not None:
  71. support_roles = [r for r in runtime_conf["role"] if r in support_roles]
  72. role_on_module = copy.deepcopy(runtime_conf["role"])
  73. for role in runtime_conf["role"]:
  74. if role not in support_roles:
  75. del role_on_module[role]
  76. if local_role not in role_on_module:
  77. return {}
  78. conf = dict()
  79. for key, value in runtime_conf.items():
  80. if key not in [
  81. "algorithm_parameters",
  82. "role_parameters",
  83. "component_parameters",
  84. ]:
  85. conf[key] = value
  86. conf["role"] = role_on_module
  87. conf["local"] = runtime_conf.get("local", {})
  88. conf["local"].update({"role": local_role, "party_id": local_party_id})
  89. conf["module"] = module
  90. conf["CodePath"] = provider.get(module, provider_components).get_run_obj_name(
  91. local_role
  92. )
  93. param_class = provider.get(module, provider_components).get_param_obj(alias)
  94. role_idx = role_on_module[local_role].index(local_party_id)
  95. user_specified_parameters = dict()
  96. if pre_parameters:
  97. if parse_user_specified_only:
  98. user_specified_parameters.update(
  99. pre_parameters.get("ComponentParam", {})
  100. )
  101. else:
  102. param_class = param_class.update(
  103. pre_parameters.get("ComponentParam", {})
  104. )
  105. common_parameters = (
  106. runtime_conf.get("component_parameters", {}).get("common", {}).get(alias, {})
  107. )
  108. if parse_user_specified_only:
  109. user_specified_parameters.update(common_parameters)
  110. else:
  111. param_class = param_class.update(
  112. common_parameters, not redundant_param_check
  113. )
  114. # update role parameters
  115. for role_id, role_id_parameters in (
  116. runtime_conf.get("component_parameters", {})
  117. .get("role", {})
  118. .get(local_role, {})
  119. .items()
  120. ):
  121. if role_id == "all" or str(role_idx) in role_id.split("|"):
  122. parameters = role_id_parameters.get(alias, {})
  123. if parse_user_specified_only:
  124. user_specified_parameters.update(parameters)
  125. else:
  126. param_class.update(parameters, not redundant_param_check)
  127. if not parse_user_specified_only:
  128. conf["ComponentParam"] = param_class.as_dict()
  129. param_class.check()
  130. else:
  131. conf["ComponentParam"] = user_specified_parameters
  132. return conf
  133. @staticmethod
  134. def convert_parameters_v1_to_v2(party_idx, parameter_v1, not_builtin_vars):
  135. parameter_v2 = {}
  136. for key, values in parameter_v1.items():
  137. # stop here, values support to be a list
  138. if key not in not_builtin_vars:
  139. parameter_v2[key] = values[party_idx]
  140. else:
  141. parameter_v2[key] = RuntimeConfParserUtil.convert_parameters_v1_to_v2(party_idx, values, not_builtin_vars)
  142. return parameter_v2
  143. @staticmethod
  144. def get_v1_role_parameters(provider, component, runtime_conf, dsl):
  145. component_role_parameters = dict()
  146. if "role_parameters" not in runtime_conf:
  147. return component_role_parameters
  148. role_parameters = runtime_conf["role_parameters"]
  149. module = dsl["components"][component]["module"]
  150. if module == "Reader":
  151. data_key = dsl["components"][component]["output"]["data"][0]
  152. for role, role_params in role_parameters.items():
  153. if not role_params.get("args", {}).get("data", {}).get(data_key):
  154. continue
  155. component_role_parameters[role] = dict()
  156. dataset = role_params["args"]["data"][data_key]
  157. for idx, table in enumerate(dataset):
  158. component_role_parameters[role][str(idx)] = {component: {"table": table}}
  159. else:
  160. provider_components = ComponentRegistry.get_provider_components(
  161. provider.provider_name, provider.provider_version
  162. )
  163. param_class = provider.get(module, provider_components).get_param_obj(component)
  164. extract_not_builtin = getattr(param_class, "extract_not_builtin", None)
  165. not_builtin_vars = extract_not_builtin() if extract_not_builtin is not None else {}
  166. for role, role_params in role_parameters.items():
  167. params = role_params.get(component, {})
  168. if not params:
  169. continue
  170. component_role_parameters[role] = dict()
  171. party_num = len(runtime_conf["role"][role])
  172. for party_idx in range(party_num):
  173. party_param = RuntimeConfParserUtil.convert_parameters_v1_to_v2(party_idx, params, not_builtin_vars)
  174. component_role_parameters[role][str(party_idx)] = {component: party_param}
  175. return component_role_parameters
  176. @staticmethod
  177. def get_job_providers_by_dsl(dsl, provider_detail):
  178. provider_info = {}
  179. global_provider_name = None
  180. global_provider_version = None
  181. if "provider" in dsl:
  182. global_provider_msg = dsl["provider"].split("@", -1)
  183. if global_provider_msg[0] == "@" or len(global_provider_msg) > 2:
  184. raise ValueError("Provider format should be provider_name@provider_version or provider_name, "
  185. "@provider_version is not supported")
  186. if len(global_provider_msg) == 1:
  187. global_provider_name = global_provider_msg[0]
  188. else:
  189. global_provider_name, global_provider_version = global_provider_msg
  190. for component in dsl["components"]:
  191. module = dsl["components"][component]["module"]
  192. provider_config = dsl["components"][component].get("provider")
  193. name, version = RuntimeConfParserUtil.get_component_provider_by_user_conf(component,
  194. module,
  195. provider_config,
  196. provider_detail,
  197. global_provider_name,
  198. global_provider_version)
  199. provider_info.update({component: {
  200. "module": module,
  201. "provider": {
  202. "name": name,
  203. "version": version
  204. }
  205. }})
  206. return provider_info
  207. @classmethod
  208. def get_job_providers(cls, dsl, provider_detail, submit_dict=None, local_role=None, local_party_id=None):
  209. provider_info = cls.get_job_providers_by_dsl(dsl, provider_detail)
  210. if submit_dict is None:
  211. return provider_info
  212. else:
  213. if local_party_id is None or local_role is None \
  214. or local_role not in submit_dict["role"] or \
  215. (str(local_party_id) not in submit_dict["role"][local_role]
  216. and int(local_party_id) not in submit_dict["role"][local_role]):
  217. raise ValueError("when parse provider from conf, local role & party_id should should be None")
  218. provider_info_all_party = {}
  219. dsl_version = submit_dict.get("dsl_version", 1)
  220. if dsl_version == 1 or "provider" not in submit_dict:
  221. for role in submit_dict["role"]:
  222. party_id_list = submit_dict["role"][role]
  223. provider_info_all_party[role] = {party_id: dict() for party_id in party_id_list}
  224. provider_info_all_party[local_role][local_party_id] = provider_info
  225. else:
  226. provider_config = submit_dict["provider"]
  227. common_provider_config = provider_config.get("common", {})
  228. other_party_provider_config = dict()
  229. if common_provider_config:
  230. for component, provider_msg in common_provider_config.items():
  231. if component not in provider_info:
  232. raise ValueError(f"Redundant omponent {component} is not found in dsl")
  233. module = provider_info[component]["module"]
  234. name, version = cls.get_component_provider_by_user_conf(component,
  235. module,
  236. provider_msg,
  237. provider_detail)
  238. provider_info[component]["provider"] = dict(name=name, version=version)
  239. other_name, other_version = cls.get_component_provider_by_user_conf(component,
  240. module,
  241. provider_msg)
  242. other_party_provider_config[component] = {
  243. "module": module,
  244. "provider": {
  245. "name": other_name,
  246. "version": other_version
  247. }
  248. }
  249. provider_info_all_party[local_role]= {local_party_id : copy.deepcopy(provider_info)}
  250. for role in submit_dict["role"]:
  251. if role not in provider_info_all_party:
  252. provider_info_all_party[role] = {}
  253. role_provider_config = provider_config.get("role", {}).get(role, {})
  254. for idx, party_id in enumerate(submit_dict["role"][role]):
  255. if role == local_role and party_id == local_party_id:
  256. provider_info_party = copy.deepcopy(provider_info)
  257. else:
  258. provider_info_party = copy.deepcopy(other_party_provider_config)
  259. for role_id, role_id_provider_config in role_provider_config.items():
  260. if role_id == "all" or str(idx) in role_id.split("|", -1):
  261. for component, provider_msg in role_id_provider_config.items():
  262. module = dsl["components"][component]["module"]
  263. detail_info = provider_detail if role == role and party_id == local_party_id else None
  264. name, version = cls.get_component_provider_by_user_conf(component,
  265. module,
  266. provider_msg,
  267. provider_detail=detail_info)
  268. if component not in provider_info_party:
  269. provider_info_party[component] = dict(module=module)
  270. provider_info_party[component]["provider"] = dict(name=name, version=version)
  271. provider_info_all_party[role][party_id] = provider_info_party
  272. return provider_info_all_party
  273. @staticmethod
  274. def get_component_provider_by_user_conf(component, module, provider_config, provider_detail=None,
  275. default_name=None, default_version=None):
  276. name, version = None, None
  277. if provider_config:
  278. provider_msg = provider_config.split("@", -1)
  279. if provider_config[0] == "@" or len(provider_msg) > 2:
  280. raise ValueError("Provider format should be provider_name@provider_version or provider_name, "
  281. "@provider_version is not supported")
  282. if len(provider_msg) == 2:
  283. name, version = provider_config.split("@", -1)
  284. else:
  285. name = provider_msg[0]
  286. if not name:
  287. if default_name:
  288. name = default_name
  289. version = default_version
  290. if provider_detail is None:
  291. return name, version
  292. if name and name not in provider_detail["components"][module]["support_provider"]:
  293. raise ValueError(f"Provider: {name} does not support in {module}, please register")
  294. if version and version not in provider_detail["providers"][name]:
  295. raise ValueError(f"Provider: {name} version: {version} does not support in {module}, please register")
  296. if name and not version:
  297. version = RuntimeConfParserUtil.get_component_provider(alias=component,
  298. module=module,
  299. provider_detail=provider_detail,
  300. name=name)
  301. elif not name and not version:
  302. name, version = RuntimeConfParserUtil.get_component_provider(alias=component,
  303. module=module,
  304. provider_detail=provider_detail)
  305. return name, version
  306. @staticmethod
  307. def get_component_provider(alias, module, provider_detail, detect=True, name=None):
  308. if module not in provider_detail["components"]:
  309. if detect:
  310. raise ValueError(f"component {alias}, module {module}'s provider does not exist")
  311. else:
  312. return None
  313. if name is None:
  314. name = provider_detail["components"][module]["default_provider"]
  315. version = provider_detail["providers"][name]["default"]["version"]
  316. return name, version
  317. else:
  318. if name not in provider_detail["components"][module]["support_provider"]:
  319. raise ValueError(f"Provider {name} does not support, please register in fate-flow")
  320. version = provider_detail["providers"][name]["default"]["version"]
  321. return version
  322. @staticmethod
  323. def instantiate_component_provider(provider_detail, alias=None, module=None, provider_name=None,
  324. provider_version=None, local_role=None, local_party_id=None,
  325. detect=True, provider_cache=None, job_parameters=None):
  326. if provider_name and provider_version:
  327. provider_path = provider_detail["providers"][provider_name][provider_version]["path"]
  328. provider = provider_utils.get_provider_interface(ComponentProvider(name=provider_name,
  329. version=provider_version,
  330. path=provider_path,
  331. class_path=ComponentRegistry.get_default_class_path()))
  332. if provider_cache is not None:
  333. if provider_name not in provider_cache:
  334. provider_cache[provider_name] = {}
  335. provider_cache[provider_name][provider_version] = provider
  336. return provider
  337. provider_name, provider_version = RuntimeConfParserUtil.get_component_provider(alias=alias,
  338. module=module,
  339. provider_detail=provider_detail,
  340. detect=detect)
  341. return RuntimeConfParserUtil.instantiate_component_provider(provider_detail,
  342. provider_name=provider_name,
  343. provider_version=provider_version)
  344. @classmethod
  345. def merge_predict_runtime_conf(cls, train_conf, predict_conf):
  346. runtime_conf = copy.deepcopy(train_conf)
  347. train_role = train_conf.get("role")
  348. predict_role = predict_conf.get("role")
  349. if len(train_conf) < len(predict_role):
  350. raise ValueError(f"Predict roles is {predict_role}, train roles is {train_conf}, "
  351. "predict roles should be subset of train role")
  352. for role in train_role:
  353. if role not in predict_role:
  354. del runtime_conf["role"][role]
  355. if runtime_conf.get("job_parameters", {}).get("role", {}).get(role):
  356. del runtime_conf["job_parameters"]["role"][role]
  357. if runtime_conf.get("component_parameters", {}).get("role", {}).get(role):
  358. del runtime_conf["component_parameters"]["role"][role]
  359. continue
  360. train_party_ids = train_role[role]
  361. predict_party_ids = predict_role[role]
  362. diff = False
  363. for idx, party_id in enumerate(predict_party_ids):
  364. if party_id not in train_party_ids:
  365. raise ValueError(f"Predict role: {role} party_id: {party_id} not occurs in training")
  366. if train_party_ids[idx] != party_id:
  367. diff = True
  368. if not diff and len(train_party_ids) == len(predict_party_ids):
  369. continue
  370. for p_type in ["job_parameters", "component_parameters"]:
  371. if not runtime_conf.get(p_type, {}).get("role", {}).get(role):
  372. continue
  373. conf = runtime_conf[p_type]["role"][role]
  374. party_keys = conf.keys()
  375. new_conf = {}
  376. for party_key in party_keys:
  377. party_list = party_key.split("|", -1)
  378. new_party_list = []
  379. for party in party_list:
  380. party_id = train_party_ids[int(party)]
  381. if party_id in predict_party_ids:
  382. new_idx = predict_party_ids.index(party_id)
  383. new_party_list.append(str(new_idx))
  384. if not new_party_list:
  385. continue
  386. new_party_key = new_party_list[0] if len(new_party_list) == 1 else "|".join(new_party_list)
  387. if new_party_key not in new_conf:
  388. new_conf[new_party_key] = {}
  389. new_conf[new_party_key].update(conf[party_key])
  390. runtime_conf[p_type]["role"][role] = new_conf
  391. runtime_conf = cls.merge_dict(runtime_conf, predict_conf)
  392. return runtime_conf
  393. @staticmethod
  394. def get_model_loader_alias(component_name, runtime_conf, local_role, local_party_id):
  395. role_params = runtime_conf.get("component_parameters", {}).get("role", {}).get("local_role")
  396. if not role_params:
  397. return runtime_conf.get("component_parameters", {}).\
  398. get("common", {}).get(component_name, {}).get("component_name")
  399. party_idx = runtime_conf.get("role").get(local_role).index(local_party_id)
  400. for id_list, params in role_params.times():
  401. ids = id_list.split("|", -1)
  402. if ids == "all" or str(party_idx) in ids:
  403. if params.get(component_name, {}).get("component_name"):
  404. model_load_alias = params.get(component_name, {}).get("component_name")
  405. return model_load_alias
  406. return runtime_conf.get("component_parameters", {}). \
  407. get("common", {}).get(component_name, {}).get("component_name")
  408. class RuntimeConfParserV1(object):
  409. @staticmethod
  410. def get_job_parameters(submit_dict):
  411. ret = {}
  412. job_parameters = submit_dict.get("job_parameters", {})
  413. for role in submit_dict["role"]:
  414. party_id_list = submit_dict["role"][role]
  415. ret[role] = {party_id: copy.deepcopy(job_parameters) for party_id in party_id_list}
  416. return ret
  417. class RuntimeConfParserV2(object):
  418. @classmethod
  419. def get_input_parameters(cls, submit_dict, components=None):
  420. if submit_dict.get("component_parameters", {}).get("role") is None or components is None:
  421. return {}
  422. roles = submit_dict["component_parameters"]["role"].keys()
  423. if not roles:
  424. return {}
  425. input_parameters = {"dsl_version": 2}
  426. cpn_dict = {}
  427. for reader_cpn in components:
  428. cpn_dict[reader_cpn] = {}
  429. for role in roles:
  430. role_parameters = submit_dict["component_parameters"]["role"][role]
  431. input_parameters[role] = [copy.deepcopy(cpn_dict)] * len(submit_dict["role"][role])
  432. for idx, parameters in role_parameters.items():
  433. for reader in components:
  434. if reader not in parameters:
  435. continue
  436. if idx == "all":
  437. party_id_list = submit_dict["role"][role]
  438. for i in range(len(party_id_list)):
  439. input_parameters[role][i][reader] = parameters[reader]
  440. elif len(idx.split("|")) == 1:
  441. input_parameters[role][int(idx)][reader] = parameters[reader]
  442. else:
  443. id_set = list(map(int, idx.split("|")))
  444. for _id in id_set:
  445. input_parameters[role][_id][reader] = parameters[reader]
  446. return input_parameters
  447. @staticmethod
  448. def get_job_parameters(submit_dict):
  449. ret = {}
  450. job_parameters = submit_dict.get("job_parameters", {})
  451. common_job_parameters = job_parameters.get("common", {})
  452. role_job_parameters = job_parameters.get("role", {})
  453. for role in submit_dict["role"]:
  454. party_id_list = submit_dict["role"][role]
  455. if not role_job_parameters:
  456. ret[role] = {party_id: copy.deepcopy(common_job_parameters) for party_id in party_id_list}
  457. continue
  458. ret[role] = {}
  459. for idx in range(len(party_id_list)):
  460. role_ids = role_job_parameters.get(role, {}).keys()
  461. parameters = copy.deepcopy(common_job_parameters)
  462. for role_id in role_ids:
  463. if role_id == "all" or str(idx) in role_id.split("|"):
  464. parameters = RuntimeConfParserUtil.merge_dict(parameters,
  465. role_job_parameters.get(role, {})[role_id])
  466. ret[role][party_id_list[idx]] = parameters
  467. return ret
  468. @staticmethod
  469. def generate_predict_conf_template(predict_dsl, train_conf, model_id, model_version):
  470. if not train_conf.get("role") or not train_conf.get("initiator"):
  471. raise ValueError("role and initiator should be contain in job's trainconf")
  472. predict_conf = dict()
  473. predict_conf["dsl_version"] = 2
  474. predict_conf["role"] = train_conf.get("role")
  475. predict_conf["initiator"] = train_conf.get("initiator")
  476. predict_conf["job_parameters"] = train_conf.get("job_parameters", {})
  477. predict_conf["job_parameters"]["common"].update({"model_id": model_id,
  478. "model_version": model_version,
  479. "job_type": "predict"})
  480. predict_conf["component_parameters"] = {"role": {}}
  481. for role in predict_conf["role"]:
  482. if role not in ["guest", "host"]:
  483. continue
  484. reader_components = []
  485. for module_alias, module_info in predict_dsl.get("components", {}).items():
  486. if module_info["module"] == "Reader":
  487. reader_components.append(module_alias)
  488. predict_conf["component_parameters"]["role"][role] = dict()
  489. fill_template = {}
  490. for idx, reader_alias in enumerate(reader_components):
  491. fill_template[reader_alias] = {"table": {"name": "name_to_be_filled_" + str(idx),
  492. "namespace": "namespace_to_be_filled_" + str(idx)}}
  493. for idx in range(len(predict_conf["role"][role])):
  494. predict_conf["component_parameters"]["role"][role][str(idx)] = fill_template
  495. return predict_conf