#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
#
#
################################################################################
# =============================================================================
# DSL PARSER
# =============================================================================
import copy
import json

from fate_flow.settings import stat_logger
from fate_flow.utils.dsl_exception import DSLNotExistError, ComponentFieldNotExistError, \
    ModuleFieldNotExistError, ComponentInputTypeError, \
    InputComponentNotExistError, InputNameNotExistError, ComponentInputDataTypeError, \
    ComponentInputValueTypeError, \
    ComponentNotExistError, ModeError, DataNotExistInSubmitConfError, ComponentOutputTypeError, \
    ComponentOutputKeyTypeError, LoopError, ComponentMultiMappingError, NamingIndexError, \
    NamingError, NamingFormatError, DeployComponentNotExistError, ModuleNotExistError
from fate_flow.utils.runtime_conf_parse_util import RuntimeConfParserUtil

ComponentParameterSource = "ComponentParameterSource"


class Component(object):
    def __init__(self):
        self.module = None
        self.name = None
        self.upstream = []
        self.downstream = []
        self.role_parameters = {}
        self.input = {}
        self.output = {}
        self.component_provider = None

    def copy(self):
        copy_obj = Component()
        copy_obj.set_module(self.module)
        copy_obj.set_name(self.name)
        copy_obj.set_input(self.input)
        copy_obj.set_downstream(self.downstream)
        copy_obj.set_upstream(self.upstream)
        copy_obj.set_role_parameters(self.role_parameters)
        copy_obj.set_output(self.output)

        return copy_obj

    def set_input(self, inp):
        self.input = inp

    def get_input(self):
        return self.input

    def set_output(self, output):
        self.output = output

    def get_output(self):
        return self.output

    def get_module(self):
        return self.module

    def set_component_provider(self, interface):
        self.component_provider = interface

    def get_component_provider(self):
        return self.component_provider

    def get_name(self):
        return self.name

    def get_upstream(self):
        return self.upstream

    def get_downstream(self):
        return self.downstream

    def set_name(self, name):
        self.name = name

    def set_module(self, module):
        self.module = module

    def set_upstream(self, upstream):
        self.upstream = upstream

    def set_downstream(self, downstream):
        self.downstream = downstream

    def set_role_parameters(self, role_parameters):
        self.role_parameters = role_parameters

    def get_role_parameters(self):
        return self.role_parameters
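

# Descriptive note: BaseDSLParser holds the parsed job graph. It indexes the
# components declared in the DSL, derives upstream/downstream edges from each
# component's "input"/"output" sections, topologically sorts the graph, and
# resolves per-role component parameters via RuntimeConfParserUtil providers.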
class BaseDSLParser(object):
    def __init__(self):
        self.dsl = None
        self.mode = "train"
        self.components = []
        self.component_name_index = {}
        self.component_upstream = []
        self.component_downstream = []
        self.train_input_model = {}
        self.in_degree = []
        self.topo_rank = []
        self.predict_dsl = {}
        self.runtime_conf = {}
        self.pipeline_runtime_conf = {}
        self.graph_dependency = None
        self.args_input = None
        self.args_data_key = None
        self.next_component_to_topo = set()
        self.job_parameters = {}
        self.provider_cache = {}
        self.job_providers = {}
        self.version = 2
        self.local_role = None
        self.local_party_id = None
        self.predict_runtime_conf = {}

    def _init_components(self, mode="train", version=1, **kwargs):
        if not self.dsl:
            raise DSLNotExistError("")

        components = self.dsl.get("components")
        if components is None:
            raise ComponentFieldNotExistError()

        for name in components:
            if "module" not in components[name]:
                raise ModuleFieldNotExistError(component=name)

            module = components[name]["module"]

            new_component = Component()
            new_component.set_name(name)
            new_component.set_module(module)
            self.component_name_index[name] = len(self.component_name_index)
            self.components.append(new_component)

        if version == 2 or mode == "train":
            self._check_component_valid_names()

    def _check_component_valid_names(self):
        for component in self.components:
            name = component.get_name()

            for chk in name:
                if chk.isalpha() or chk in ["_", "-"] or chk.isdigit():
                    continue
                else:
                    raise NamingFormatError(component=name)

    def _find_dependencies(self, mode="train", version=1):
        self.component_downstream = [[] for _ in range(len(self.components))]
        self.component_upstream = [[] for _ in range(len(self.components))]

        components_details = self.dsl.get("components")
        components_output = self._find_outputs(self.dsl)

        for name in self.component_name_index.keys():
            idx = self.component_name_index.get(name)
            upstream_input = components_details.get(name).get("input")
            downstream_output = components_details.get(name).get("output", {})

            self.components[idx].set_output(downstream_output)
            if upstream_input is None:
                continue
            elif not isinstance(upstream_input, dict):
                raise ComponentInputTypeError(component=name)
            else:
                self.components[idx].set_input(upstream_input)

            if mode == "train":
                input_keywords = {"model": "model", "isometric_model": "model", "cache": "cache"}
            else:
                input_keywords = {"cache": "cache"}

            for keyword, out_type in input_keywords.items():
                if keyword in upstream_input:
                    input_list = upstream_input.get(keyword)
                    if not isinstance(input_list, list):
                        raise ComponentInputValueTypeError(component=name, value_type="model",
                                                           other_info=input_list)

                    for _input in input_list:
                        input_component = _input.split(".", -1)[0]
                        input_model_name = _input.split(".")[-1]
                        if input_component not in self.component_name_index:
                            raise InputComponentNotExistError(component=name, value_type=keyword, input=input_component)
                        else:
                            if input_component not in components_output or out_type not in components_output[input_component]:
                                raise InputNameNotExistError(component=name, input=input_component,
                                                             value_type=keyword, other_info=input_model_name)

                            idx_dependency = self.component_name_index.get(input_component)
                            self.component_downstream[idx_dependency].append(name)
                            self.component_upstream[idx].append(input_component)

                            if keyword == "model" or keyword == "cache":
                                self.train_input_model[name] = input_component

            if "data" in upstream_input:
                data_dict = upstream_input.get("data")
                if not isinstance(data_dict, dict):
                    raise ComponentInputDataTypeError(component=name)

                for data_set in data_dict:
                    if not isinstance(data_dict.get(data_set), list):
                        raise ComponentInputValueTypeError(component=name, value_type="data",
                                                           other_info=data_dict.get(data_set))

                    if version == 2 and data_set not in ["data", "train_data", "validate_data", "test_data",
                                                         "eval_data"]:
                        stat_logger.warning(
                            "DSLParser Warning: make sure that input data's data key should be in {}, but {} found".format(
                                ["data", "train_data", "validate_data", "test_data", "eval_data"], data_set))

                    for data_key in data_dict.get(data_set):
                        input_component = data_key.split(".", -1)[0]
                        input_data_name = data_key.split(".", -1)[-1]
                        if input_component not in self.component_name_index:
                            raise InputComponentNotExistError(component=name, value_type="data",
                                                              input=input_component)
                        else:
                            if input_component not in components_output \
                                    or "data" not in components_output[input_component] \
                                    or input_data_name not in components_output[input_component]["data"]:
                                raise InputNameNotExistError(component=name, input=input_component,
                                                             value_type="data", other_info=input_data_name)

                            idx_dependency = self.component_name_index.get(input_component)
                            self.component_downstream[idx_dependency].append(name)
                            self.component_upstream[idx].append(input_component)

        self.in_degree = [0 for _ in range(len(self.components))]
        for i in range(len(self.components)):
            if self.component_downstream[i]:
                self.component_downstream[i] = list(set(self.component_downstream[i]))

            if self.component_upstream[i]:
                self.component_upstream[i] = list(set(self.component_upstream[i]))
                self.in_degree[self.component_name_index.get(self.components[i].get_name())] = len(
                    self.component_upstream[i])

        self._check_dag_dependencies()

        for i in range(len(self.components)):
            self.components[i].set_upstream(self.component_upstream[i])
            self.components[i].set_downstream(self.component_downstream[i])

    def _init_component_setting(self,
                                component,
                                provider_detail,
                                provider_name,
                                provider_version,
                                local_role,
                                local_party_id,
                                runtime_conf,
                                redundant_param_check=True,
                                parse_user_specified_only=False,
                                previous_parameters=None
                                ):
        """
        init top input
        """
        provider = RuntimeConfParserUtil.instantiate_component_provider(provider_detail,
                                                                        provider_name=provider_name,
                                                                        provider_version=provider_version)

        pos = self.component_name_index[component]
        module = self.components[pos].get_module()

        parent_path = [component]
        cur_component = component
        isometric_component = None

        while True:
            if self.train_input_model.get(cur_component, None) is None:
                break
            else:
                is_warm_start = self._is_warm_start(cur_component)
                is_same_module = True

                input_component = self.train_input_model.get(cur_component)
                input_pos = self.component_name_index[input_component]
                if self.components[input_pos].get_module() != module:
                    is_same_module = False

                if not is_warm_start and is_same_module:
                    cur_component = self.train_input_model.get(cur_component)
                    parent_path.append(cur_component)
                else:
                    if (is_warm_start or not is_same_module) and self.components[input_pos].get_module().lower() == "modelloader":
                        model_load_alias = RuntimeConfParserUtil.get_model_loader_alias(input_component, runtime_conf,
                                                                                        local_role, local_party_id)
                        isometric_component = model_load_alias
                    else:
                        isometric_component = input_component

                    break

        pre_parameters = {}
        if previous_parameters is not None:
            if not isometric_component:
                pre_parameters = previous_parameters.get(cur_component, {})
            else:
                pre_parameters = previous_parameters.get(isometric_component, {})

        if self.mode == "predict" and pre_parameters:
            source_component = previous_parameters.get(component, {}).get(ComponentParameterSource)
            if source_component and source_component != cur_component:
                runtime_conf = self.runtime_conf

        role_parameters = RuntimeConfParserUtil.get_component_parameters(provider,
                                                                         runtime_conf,
                                                                         module,
                                                                         cur_component,
                                                                         redundant_param_check=redundant_param_check,
                                                                         local_role=local_role,
                                                                         local_party_id=local_party_id,
                                                                         parse_user_specified_only=parse_user_specified_only,
                                                                         pre_parameters=pre_parameters)

        if role_parameters:
            role_parameters[ComponentParameterSource] = cur_component

        for component in parent_path:
            idx = self.component_name_index.get(component)
            self.components[idx].set_component_provider(provider)
            self.components[idx].set_role_parameters(role_parameters)

        return role_parameters

    def _is_warm_start(self, component_name):
        component_idx = self.component_name_index.get(component_name)
        upstream_inputs = self.components[component_idx].get_input()
        if not upstream_inputs:
            return False

        return "train_data" in upstream_inputs.get("data", {}) and "model" in upstream_inputs

    def parse_component_parameters(self, component_name, provider_detail, provider_name, provider_version, local_role,
                                   local_party_id, previous_parameters=None):
        if self.mode == "predict":
            runtime_conf = self.predict_runtime_conf
        else:
            runtime_conf = self.runtime_conf

        redundant_param_check = True
        parameters = self._init_component_setting(component_name,
                                                  provider_detail,
                                                  provider_name,
                                                  provider_version,
                                                  local_role,
                                                  local_party_id,
                                                  runtime_conf,
                                                  redundant_param_check,
                                                  parse_user_specified_only=False,
                                                  previous_parameters=previous_parameters)

        return parameters

    def get_component_info(self, component_name):
        if component_name not in self.component_name_index:
            raise ComponentNotExistError(component=component_name)

        idx = self.component_name_index.get(component_name)
        return self.components[idx]

    def get_upstream_dependent_components(self, component_name):
        dependent_component_names = self.get_component_info(component_name).get_upstream()
        dependent_components = []
        for up_cpn in dependent_component_names:
            up_cpn_idx = self.component_name_index.get(up_cpn)
            dependent_components.append(self.components[up_cpn_idx])

        return dependent_components

    def get_downstream_dependent_components(self, component_name):
        component_idx = self.component_name_index.get(component_name)
        downstream_components = []
        for cpn in self.component_downstream[component_idx]:
            down_cpn_idx = self.component_name_index.get(cpn)
            downstream_components.append(self.components[down_cpn_idx])

        return downstream_components

    def get_topology_components(self):
        topo_components = []
        for i in range(len(self.topo_rank)):
            topo_components.append(self.components[self.topo_rank[i]])

        return topo_components

    @staticmethod
    def _find_outputs(dsl):
        outputs = {}

        components_details = dsl.get("components")

        for name in components_details.keys():
            if "output" not in components_details.get(name):
                continue

            component_output = components_details.get(name).get("output")
            output_keys = ["data", "model", "cache"]

            if not isinstance(component_output, dict):
                raise ComponentOutputTypeError(component=name, other_info=component_output)

            for key in output_keys:
                if key not in component_output:
                    continue

                out_v = component_output.get(key)
                if not isinstance(out_v, list):
                    raise ComponentOutputKeyTypeError(component=name, other_info=key)

                if name not in outputs:
                    outputs[name] = {}

                outputs[name][key] = out_v

        return outputs
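
    # Descriptive note: the two methods below detect cycles. _check_dag_dependencies
    # runs Kahn's algorithm, accumulating a topological order in self.topo_rank; if
    # some components are never reached, their in-degree never drops to zero and
    # _find_loop reconstructs one offending cycle for the LoopError message.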
    def _check_dag_dependencies(self):
        in_degree = copy.deepcopy(self.in_degree)
        stack = []

        for i in range(len(self.components)):
            if in_degree[i] == 0:
                stack.append(i)

        tot_nodes = 0

        while len(stack) > 0:
            idx = stack.pop()
            tot_nodes += 1
            self.topo_rank.append(idx)

            for down_name in self.component_downstream[idx]:
                down_idx = self.component_name_index.get(down_name)
                in_degree[down_idx] -= 1

                if in_degree[down_idx] == 0:
                    stack.append(down_idx)

        if tot_nodes != len(self.components):
            stack = []
            vis = [False for _ in range(len(self.components))]
            for i in range(len(self.components)):
                if vis[i]:
                    continue

                loops = []
                self._find_loop(i, vis, stack, loops)
                raise LoopError(loops)

    def _find_loop(self, u, vis, stack, loops):
        vis[u] = True
        stack.append(u)

        for down_name in self.component_downstream[u]:
            if loops:
                return

            v = self.component_name_index.get(down_name)
            if v not in stack:
                if not vis[v]:
                    self._find_loop(v, vis, stack, loops)
            else:
                index = stack.index(v)
                for node in stack[index:]:
                    loops.append(self.components[node].get_name())

                return

        stack.pop(-1)

    def prepare_graph_dependency_info(self):
        dependence_dict = {}
        component_module = {}
        for component in self.components:
            name = component.get_name()
            module = component.get_module()
            component_module[name] = module
            if not component.get_input():
                continue

            dependence_dict[name] = []
            inputs = component.get_input()
            if "data" in inputs:
                data_input = inputs["data"]
                for data_key, data_list in data_input.items():
                    for dataset in data_list:
                        up_component_name = dataset.split(".", -1)[0]
                        up_pos = self.component_name_index.get(up_component_name)
                        up_component = self.components[up_pos]
                        data_name = dataset.split(".", -1)[1]
                        if up_component.get_output().get("data"):
                            data_pos = up_component.get_output().get("data").index(data_name)
                        else:
                            data_pos = 0

                        if data_key == "data" or data_key == "train_data":
                            data_type = data_key
                        else:
                            data_type = "validate_data"

                        dependence_dict[name].append({"component_name": up_component_name,
                                                      "type": data_type,
                                                      "up_output_info": ["data", data_pos]})

            input_keyword_type_mapping = {"model": "model",
                                          "isometric_model": "model",
                                          "cache": "cache"}
            for keyword, v_type in input_keyword_type_mapping.items():
                if keyword in inputs:
                    input_list = inputs[keyword]
                    for _input in input_list:
                        up_component_name = _input.split(".", -1)[0]
                        if up_component_name == "pipeline":
                            continue

                        link_alias = _input.split(".", -1)[1]
                        up_pos = self.component_name_index.get(up_component_name)
                        up_component = self.components[up_pos]
                        if up_component.get_output().get(v_type):
                            dep_pos = up_component.get_output().get(v_type).index(link_alias)
                        else:
                            dep_pos = 0

                        dependence_dict[name].append({"component_name": up_component_name,
                                                      "type": v_type,
                                                      "up_output_info": [v_type, dep_pos]})

            if not dependence_dict[name]:
                del dependence_dict[name]

        component_list = [None for _ in range(len(self.components))]
        topo_rank_reverse_mapping = {}
        for i in range(len(self.topo_rank)):
            topo_rank_reverse_mapping[self.topo_rank[i]] = i

        for key, value in self.component_name_index.items():
            topo_rank_idx = topo_rank_reverse_mapping[value]
            component_list[topo_rank_idx] = key

        base_dependency = {"component_list": component_list,
                           "dependencies": dependence_dict,
                           "component_module": component_module,
                           "component_need_run": {}}

        self.graph_dependency = base_dependency

    def get_dsl_hierarchical_structure(self):
        max_depth = [0] * len(self.components)
        for idx in range(len(self.topo_rank)):
            vertex = self.topo_rank[idx]
            for down_name in self.component_downstream[vertex]:
                down_vertex = self.component_name_index.get(down_name)
                max_depth[down_vertex] = max(max_depth[down_vertex], max_depth[vertex] + 1)

        max_dep = max(max_depth)
        hierarchical_structure = [[] for _ in range(max_dep + 1)]
        name_component_maps = {}
        for component in self.components:
            name = component.get_name()
            vertex = self.component_name_index.get(name)
            hierarchical_structure[max_depth[vertex]].append(name)
            name_component_maps[name] = component

        return name_component_maps, hierarchical_structure

    def get_dependency(self):
        return self.graph_dependency

    def get_dependency_with_parameters(self, component_parameters):
        return self.extract_need_run_status(self.graph_dependency, component_parameters)

    def extract_need_run_status(self, graph_dependency, component_parameters):
        for rank in range(len(self.topo_rank)):
            idx = self.topo_rank[rank]
            name = self.components[idx].get_name()
            parameters = component_parameters.get(name)

            if not parameters:
                graph_dependency["component_need_run"][name] = False
            else:
                if self.train_input_model.get(name, None) is None:
                    param_name = "ComponentParam"
                    if parameters.get(param_name) is None \
                            or parameters[param_name].get("need_run") is False:
                        graph_dependency["component_need_run"][name] = False
                    else:
                        graph_dependency["component_need_run"][name] = True
                else:
                    input_model_name = self.train_input_model.get(name)
                    graph_dependency["component_need_run"][name] = graph_dependency["component_need_run"][
                        input_model_name]

        return graph_dependency

    @staticmethod
    def verify_dsl(dsl, mode="train"):
        dsl_parser = DSLParserV2()
        dsl_parser.dsl = dsl
        dsl_parser._init_components(mode=mode, version=2)
        dsl_parser._find_dependencies(mode=mode, version=2)

    @staticmethod
    def verify_dsl_reusability(reused_dsl, new_dsl, reused_components):
        # step 1: verify the new dsl
        dsl_parser = DSLParserV2()
        dsl_parser.dsl = new_dsl
        dsl_parser._init_components(mode="train", version=2)
        dsl_parser._find_dependencies(mode="train", version=2)

        # step 2: verify that the reused components form a sub-graph
        reused_components = set(reused_components)
        # reused_components = set(reused_dsl["components"]) & set(new_dsl["components"])
        for cpn in reused_components:
            validate_key = ["input", "output", "provider"]
            for vk in validate_key:
                config_old = reused_dsl["components"][cpn].get(vk, None)
                config_new = new_dsl["components"][cpn].get(vk, None)
                if config_old != config_new:
                    raise ValueError(f"Component {cpn}'s {vk} should be the same, but old is {config_old}, new is {config_new}")

            inputs = reused_dsl["components"][cpn].get("input", {})
            list_dep_key = ["cache", "model", "isometric_model"]
            for dep_key in list_dep_key:
                dep_list = inputs.get(dep_key, [])
                for dep in dep_list:
                    input_dep = dep.split(".", -1)[0]
                    if input_dep not in reused_components:
                        raise ValueError(f"Component {cpn}'s {dep_key} input {input_dep} should be reused")

            data_dep = inputs.get("data", {})
            for data_key, data_list in data_dep.items():
                for dep in data_list:
                    input_dep = dep.split(".", -1)[0]
                    if input_dep not in reused_components:
                        raise ValueError(f"Component {cpn}'s {data_key} input {input_dep} should be reused")

    @staticmethod
    def deploy_component(components, train_dsl, provider_update_dsl=None):
        training_cpns = set(train_dsl.get("components").keys())
        deploy_cpns = set(components)
        if len(deploy_cpns & training_cpns) != len(deploy_cpns):
            raise DeployComponentNotExistError(msg=deploy_cpns - training_cpns)

        dsl_parser = DSLParserV2()
        dsl_parser.dsl = train_dsl
        dsl_parser._init_components()
        dsl_parser._find_dependencies(version=2)
        dsl_parser._auto_deduction(deploy_cpns=deploy_cpns, version=2, erase_top_data_input=True)

        """
        dsl_parser.update_predict_dsl_provider(train_dsl)
        if provider_update_dsl:
            dsl_parser.update_predict_dsl_provider(provider_update_dsl)
        """
        return dsl_parser.predict_dsl

    """
    def update_predict_dsl_provider(self, dsl):
        for component in dsl["components"]:
            provider = dsl["components"][component].get("provider")
            if provider and component in self.predict_dsl["components"]:
                self.predict_dsl["components"][component]["provider"] = provider

        if "provider" in dsl:
            self.predict_dsl["provider"] = dsl["provider"]
    """
    def _auto_deduction(self, deploy_cpns=None, version=1, erase_top_data_input=False):
        self.predict_dsl = {"components": {}}
        self.predict_components = []
        mapping_list = {}
        for i in range(len(self.topo_rank)):
            self.predict_components.append(self.components[self.topo_rank[i]].copy())
            mapping_list[self.predict_components[-1].get_name()] = i

        output_data_maps = {}
        for i in range(len(self.predict_components)):
            name = self.predict_components[i].get_name()
            module = self.predict_components[i].get_module()

            if module == "Reader":
                if version != 2:
                    raise ValueError("Reader component can only be set in dsl_version 2")

            if self.get_need_deploy_parameter(name=name, deploy_cpns=deploy_cpns):
                self.predict_dsl["components"][name] = {"module": self.predict_components[i].get_module()}

                """replace output model to pipeline"""
                if "output" in self.dsl["components"][name]:
                    model_list = self.dsl["components"][name]["output"].get("model", None)
                    if model_list is not None:
                        if "input" not in self.predict_dsl["components"][name]:
                            self.predict_dsl["components"][name]["input"] = {}

                        replace_model = [".".join(["pipeline", name, model]) for model in model_list]
                        self.predict_dsl["components"][name]["input"]["model"] = replace_model

                    for out_key, out_val in self.dsl["components"][name]["output"].items():
                        if out_val is not None and out_key != "model":
                            if "output" not in self.predict_dsl["components"][name]:
                                self.predict_dsl["components"][name]["output"] = {}

                            self.predict_dsl["components"][name]["output"][out_key] = out_val

                if "input" in self.dsl["components"][name]:
                    if "input" not in self.predict_dsl["components"][name]:
                        self.predict_dsl["components"][name]["input"] = {}

                    if "data" in self.dsl["components"][name]["input"]:
                        self.predict_dsl["components"][name]["input"]["data"] = {}
                        for data_key, data_value in self._gen_predict_data_mapping():
                            if data_key not in self.dsl["components"][name]["input"]["data"]:
                                continue

                            data_set = self.dsl["components"][name]["input"]["data"].get(data_key)
                            self.predict_dsl["components"][name]["input"]["data"][data_value] = []
                            for input_data in data_set:
                                if version == 1 and input_data.split(".")[0] == "args":
                                    new_input_data = "args.eval_data"
                                    self.predict_dsl["components"][name]["input"]["data"][data_value].append(
                                        new_input_data)
                                elif version == 2 and input_data.split(".")[0] == "args":
                                    self.predict_dsl["components"][name]["input"]["data"][data_value].append(input_data)
                                elif version == 2 and self.dsl["components"][input_data.split(".")[0]].get("module") == "Reader":
                                    self.predict_dsl["components"][name]["input"]["data"][data_value].append(input_data)
                                else:
                                    pre_name = input_data.split(".")[0]
                                    data_suffix = input_data.split(".")[1]
                                    if self.get_need_deploy_parameter(name=pre_name, deploy_cpns=deploy_cpns):
                                        self.predict_dsl["components"][name]["input"]["data"][data_value].append(input_data)
                                    else:
                                        self.predict_dsl["components"][name]["input"]["data"][data_value].extend(
                                            output_data_maps[pre_name][data_suffix])

                            break

                    if "cache" in self.dsl["components"][name]["input"]:
                        cache_set = self.dsl["components"][name]["input"]["cache"]
                        self.predict_dsl["components"][name]["input"]["cache"] = []
                        for input_cache in cache_set:
                            pre_name, cache_suffix = input_cache.split(".")[:2]
                            input_deploy = self.get_need_deploy_parameter(name=pre_name, deploy_cpns=deploy_cpns)
                            if version == 1 and not input_deploy:
                                raise ValueError("In dsl v1, if cache is enabled, input component should be deployed")

                            self.predict_dsl["components"][name]["input"]["cache"].append(input_cache)

                    if version == 2 and erase_top_data_input:
                        input_dep = {}
                        for data_key, data_set in self.predict_dsl["components"][name]["input"]["data"].items():
                            final_data_set = []
                            for input_data in data_set:
                                cpn_alias = input_data.split(".")[0]
                                if cpn_alias in self.predict_dsl["components"]:
                                    final_data_set.append(input_data)

                            if final_data_set:
                                input_dep[data_key] = final_data_set

                        if not input_dep:
                            del self.predict_dsl["components"][name]["input"]["data"]
                        else:
                            self.predict_dsl["components"][name]["input"]["data"] = input_dep
            else:
                name = self.predict_components[i].get_name()
                input_data, output_data = None, None
                if "input" in self.dsl["components"][name] and "data" in self.dsl["components"][name]["input"]:
                    input_data = self.dsl["components"][name]["input"].get("data")

                if "output" in self.dsl["components"][name] and "data" in self.dsl["components"][name]["output"]:
                    output_data = self.dsl["components"][name]["output"].get("data")

                if output_data is None or input_data is None:
                    continue

                output_data_maps[name] = {}
                for output_data_str in output_data:
                    if "train_data" in input_data or "eval_data" in input_data or "test_data" in input_data:
                        if "train_data" in input_data:
                            up_input_data = input_data.get("train_data")[0]
                        elif "eval_data" in input_data:
                            up_input_data = input_data.get("eval_data")[0]
                        else:
                            up_input_data = input_data.get("test_data")[0]
                    elif "data" in input_data:
                        up_input_data = input_data.get("data")[0]
                    else:
                        raise ValueError("train data or eval data or validate data or data should be set")

                    up_input_data_component_name = up_input_data.split(".", -1)[0]
                    if up_input_data_component_name == "args" \
                            or self.get_need_deploy_parameter(name=up_input_data_component_name,
                                                              deploy_cpns=deploy_cpns):
                        output_data_maps[name][output_data_str] = [up_input_data]
                    elif self.components[self.component_name_index.get(up_input_data_component_name)].get_module() == "Reader":
                        output_data_maps[name][output_data_str] = [up_input_data]
                    else:
                        up_input_data_suf = up_input_data.split(".", -1)[-1]
                        output_data_maps[name][output_data_str] = output_data_maps[up_input_data_component_name][up_input_data_suf]

    def run(self, *args, **kwargs):
        pass

    def get_runtime_conf(self):
        return self.runtime_conf

    def get_dsl(self):
        return self.dsl

    def get_args_input(self):
        return self.args_input

    @staticmethod
    def get_need_deploy_parameter(name, deploy_cpns=None):
        if deploy_cpns is not None:
            return name in deploy_cpns

        return False

    def get_job_parameters(self, *args, **kwargs):
        return self.job_parameters

    def get_job_providers(self, provider_detail=None, dsl=None, conf=None, local_role=None, local_party_id=None):
        if dsl is None:
            self.job_providers = RuntimeConfParserUtil.get_job_providers(self.dsl, provider_detail, conf,
                                                                         local_role, local_party_id)
        else:
            self.job_providers = RuntimeConfParserUtil.get_job_providers(dsl, provider_detail, conf,
                                                                         local_role, local_party_id)

        return self.job_providers

    @staticmethod
    def _gen_predict_data_mapping():
        data_mapping = [("data", "data"), ("train_data", "test_data"),
                        ("validate_data", "test_data"), ("test_data", "test_data")]

        for data_key, data_value in data_mapping:
            yield data_key, data_value

    @staticmethod
    def generate_predict_conf_template(predict_dsl, train_conf, model_id, model_version):
        return RuntimeConfParserUtil.generate_predict_conf_template(predict_dsl,
                                                                    train_conf,
                                                                    model_id,
                                                                    model_version)

    @staticmethod
    def get_predict_dsl(predict_dsl=None, module_object_dict=None):
        if not predict_dsl:
            return {}

        role_predict_dsl = copy.deepcopy(predict_dsl)
        component_list = list(predict_dsl.get("components").keys())

        for component in component_list:
            module_object = module_object_dict.get(component)
            if module_object:
                role_predict_dsl["components"][component]["CodePath"] = module_object

        return role_predict_dsl

    @staticmethod
    def get_module_object_name(module, local_role, provider_detail,
                               provider_name, provider_version):
        if not provider_detail:
            raise ValueError("Component Providers should be provided")

        provider = RuntimeConfParserUtil.instantiate_component_provider(provider_detail,
                                                                        provider_name=provider_name,
                                                                        provider_version=provider_version)

        module_obj_name = RuntimeConfParserUtil.get_module_name(role=local_role,
                                                                module=module,
                                                                provider=provider)

        return module_obj_name

    @staticmethod
    def validate_component_param(component, module, runtime_conf,
                                 provider_name, provider_version, provider_detail,
                                 local_role, local_party_id):
        provider = RuntimeConfParserUtil.instantiate_component_provider(provider_detail,
                                                                        provider_name=provider_name,
                                                                        provider_version=provider_version)
        try:
            RuntimeConfParserUtil.get_component_parameters(provider,
                                                           runtime_conf,
                                                           module,
                                                           component,
                                                           redundant_param_check=True,
                                                           local_role=local_role,
                                                           local_party_id=local_party_id,
                                                           parse_user_specified_only=False)
            return 0
        except Exception as e:
            raise ValueError(f"{e}")

    @classmethod
    def check_input_existence(cls, dsl):
        component_details = dsl.get("components", {})
        component_outputs = cls._find_outputs(dsl)
        input_key = ["data", "model", "isometric_model", "cache"]
        non_existence = dict()
        for cpn, cpn_detail in component_details.items():
            for k in input_key:
                input_deps = cpn_detail.get("input", {}).get(k, {})
                if not input_deps:
                    continue

                input_splits = None
                if k == "data":
                    for data_k, dep_list in input_deps.items():
                        for dep in dep_list:
                            input_splits = dep.split(".", -1)
                else:
                    for dep in input_deps:
                        input_splits = dep.split(".", -1)
                        if input_splits[0] == "pipeline":
                            input_splits = input_splits[1:]

                up_cpn, up_link = input_splits
                if not component_outputs.get(up_cpn, {}).get(up_link, {}):
                    if k not in non_existence:
                        non_existence[k] = list()
                    non_existence[k].append(f"{cpn}'s {up_cpn}.{up_link}")

        if non_existence:
            ret_msg = "non-existent input:"
            for k, v in non_existence.items():
                ret_msg += f"\n {k}: " + ",".join(v)

            return ret_msg
        else:
            return ""
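

# Descriptive note: DSLParserV1 handles dsl_version 1 jobs. convert_dsl_v1_to_v2
# rewrites "args.*" data inputs into generated Reader components, and
# convert_conf_v1_to_v2 moves job/algorithm/role parameters into the v2 layout.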
class DSLParserV1(BaseDSLParser):
    def __init__(self):
        super(DSLParserV1, self).__init__()
        self.version = 1

    @staticmethod
    def get_job_parameters(runtime_conf):
        job_parameters = RuntimeConfParserUtil.get_job_parameters(runtime_conf,
                                                                  conf_version=1)
        return job_parameters

    @staticmethod
    def parse_component_role_parameters(component, dsl, runtime_conf, provider_detail, provider_name,
                                        provider_version):
        provider = RuntimeConfParserUtil.instantiate_component_provider(provider_detail,
                                                                        provider_name=provider_name,
                                                                        provider_version=provider_version)

        role_parameters = RuntimeConfParserUtil.get_v1_role_parameters(provider,
                                                                       component,
                                                                       runtime_conf,
                                                                       dsl)

        return role_parameters

    @staticmethod
    def convert_dsl_v1_to_v2(dsl):
        dsl_v2 = copy.deepcopy(dsl)

        # change dsl v1 to dsl v2
        readers = {}
        ret_msg = []
        for cpn, cpn_detail in dsl["components"].items():
            new_cpn_detail = copy.deepcopy(cpn_detail)
            if cpn_detail.get("input", {}).get("data", {}):
                for data_key, dataset in cpn_detail["input"]["data"].items():
                    new_dataset = []
                    for data in dataset:
                        up_cpn, up_out_alias = data.split(".", -1)
                        if up_cpn == "args":
                            if up_out_alias not in readers:
                                readers[up_out_alias] = "_".join(["reader", str(len(readers))])
                                ret_msg.append(f"{data} is changed to {readers[up_out_alias]}.{up_out_alias}, please "
                                               f"set input data of {readers[up_out_alias]}")
                            up_link = ".".join([readers[up_out_alias], up_out_alias])
                            new_dataset.append(up_link)
                        else:
                            new_dataset.append(data)

                    new_cpn_detail["input"]["data"][data_key] = new_dataset

            dsl_v2["components"][cpn] = new_cpn_detail

        for output_alias, cpn in readers.items():
            reader_detail = dict(module="Reader",
                                 output={"data": [output_alias]},
                                 CodePath="Reader")
            dsl_v2["components"].update({cpn: reader_detail})

        return dsl_v2, ", ".join(ret_msg)

    @staticmethod
    def convert_conf_v1_to_v2(conf_v1, role_parameters):
        conf_v2 = dict()
        for attr, conf in conf_v1.items():
            if attr in ["algorithm_parameters", "role_parameters", "job_parameters"]:
                continue
            conf_v2[attr] = conf

        job_params = conf_v1.get("job_parameters", {})
        conf_v2["job_parameters"] = dict(common=job_params)

        algorithm_params = conf_v1.get("algorithm_parameters", {})
        if algorithm_params or conf_v1.get("role_parameters"):
            conf_v2["component_parameters"] = dict()

        if algorithm_params:
            conf_v2["component_parameters"]["common"] = algorithm_params

        if conf_v1.get("role_parameters"):
            conf_v2["component_parameters"]["role"] = dict()
            for cpn, role_params in role_parameters.items():
                conf_v2["component_parameters"]["role"] = RuntimeConfParserUtil.merge_dict(conf_v2["component_parameters"]["role"],
                                                                                           role_params)

        conf_v2["dsl_version"] = 2

        return conf_v2

    """
    @staticmethod
    def change_conf_v1_to_v2(dsl_v2, conf_v1, provider_detail):
        # change conf v1 to conf v2
        readers = dict()
        for cpn, cpn_detail in dsl_v2["components"].items():
            if cpn_detail.get("module") != "Reader":
                continue
            output_alias = cpn_detail["output"]["data"]
            readers[output_alias] = cpn

        conf_v2 = RuntimeConfParserUtil.change_conf_v1_to_v2(dsl_v2, conf_v1, readers, provider_detail)

        return conf_v2
    """

    @staticmethod
    def get_components_light_weight(dsl_v2):
        components = []
        for cpn, cpn_detail in dsl_v2["components"].items():
            component = Component()
            component.set_name(cpn)
            component.set_module(cpn_detail["module"])
            components.append(component)

        return components


class DSLParserV2(BaseDSLParser):
    def __init__(self):
        super(DSLParserV2, self).__init__()
        self.version = 2

    def run(self, pipeline_runtime_conf=None, dsl=None, runtime_conf=None,
            provider_detail=None, mode="train",
            local_role=None, local_party_id=None, *args, **kwargs):
        if mode not in ["train", "predict"]:
            raise ModeError("")

        self.dsl = copy.deepcopy(dsl)
        self._init_components(mode, version=2)
        self._find_dependencies(mode, version=2)
        self.runtime_conf = runtime_conf
        self.pipeline_runtime_conf = pipeline_runtime_conf
        self.mode = mode
        self.local_role = local_role
        self.local_party_id = local_party_id

        if mode == "train":
            self.job_parameters = RuntimeConfParserUtil.get_job_parameters(self.runtime_conf,
                                                                           conf_version=2)
        else:
            """the training provider will be deleted first"""
            pipeline_runtime_conf = copy.deepcopy(pipeline_runtime_conf)
            if "provider" in pipeline_runtime_conf:
                del pipeline_runtime_conf["provider"]

            predict_runtime_conf = RuntimeConfParserUtil.merge_predict_runtime_conf(pipeline_runtime_conf,
                                                                                    runtime_conf)
            self.predict_runtime_conf = predict_runtime_conf
            self.job_parameters = RuntimeConfParserUtil.get_job_parameters(predict_runtime_conf,
                                                                           conf_version=2)

        self.args_input = RuntimeConfParserUtil.get_input_parameters(runtime_conf,
                                                                     components=self._get_reader_components())

        self.prepare_graph_dependency_info()

        return self.components

    def parse_user_specified_component_parameters(self, component_name, provider_detail, provider_name,
                                                  provider_version, local_role, local_party_id, previous_parameters=None):
        if self.mode == "predict":
            runtime_conf = self.predict_runtime_conf
        else:
            runtime_conf = self.runtime_conf

        parameters = self._init_component_setting(component_name,
                                                  provider_detail,
                                                  provider_name,
                                                  provider_version,
                                                  local_role,
                                                  local_party_id,
                                                  runtime_conf,
                                                  redundant_param_check=False,
                                                  parse_user_specified_only=True,
                                                  previous_parameters=previous_parameters)

        return parameters

    def _get_reader_components(self):
        reader_components = []
        for cpn, conf in self.dsl.get("components").items():
            if conf.get("module") == "Reader":
                reader_components.append(cpn)

        return reader_components

    def get_source_connect_sub_graph(self, valid_nodes):
        invalid_nodes = set([self.components[i].get_name() for i in range(len(self.components))]) - set(valid_nodes)
        return self._get_source_connect_nodes(invalid_nodes)

    def get_need_revisit_nodes(self, visited_nodes, failed_nodes):
        invalid_nodes = set([self.components[i].get_name() for i in range(len(self.components))]) - set(visited_nodes)
        invalid_nodes |= set(failed_nodes)
        connected_nodes = self._get_source_connect_nodes(invalid_nodes)

        connected_nodes_name = [node.get_name() for node in connected_nodes]
        revisit_nodes = []
        for node in visited_nodes:
            if node not in connected_nodes_name:
                idx = self.component_name_index[node]
                revisit_nodes.append(self.components[idx])

        return revisit_nodes

    def _get_source_connect_nodes(self, invalid_nodes):
        in_degree = copy.deepcopy(self.in_degree)
        stack = []
        for i in range(len(self.components)):
            if self.components[i].get_name() in invalid_nodes:
                continue
            if in_degree[i] == 0:
                stack.append(i)

        connected_nodes = []
        while len(stack) > 0:
            idx = stack.pop()
            connected_nodes.append(self.components[idx])
            for down_name in self.component_downstream[idx]:
                if down_name in invalid_nodes:
                    continue

                down_idx = self.component_name_index.get(down_name)
                in_degree[down_idx] -= 1

                if in_degree[down_idx] == 0:
                    stack.append(down_idx)

        return connected_nodes

    @staticmethod
    def verify_conf_reusability(reused_conf, new_conf, reused_components):
        reused_components = set(reused_components)

        # step1: check role, it should be same
        # reused_conf_role = reused_conf.get("role", {})
        # new_conf_role = new_conf.get("role", {})
        # if reused_conf_role != new_conf_role:
        #     raise ValueError(f"role {reused_conf_role} does not equals to {new_conf_role}")

        # step2: check component common parameters
        pre_component_parameters = reused_conf.get("component_parameters", {})
        cur_component_parameters = new_conf.get("component_parameters", {})

        pre_common_params = pre_component_parameters.get("common", {})
        cur_common_params = cur_component_parameters.get("common", {})
        pre_role_params = pre_component_parameters.get("role", {})
        cur_role_params = cur_component_parameters.get("role", {})

        for cpn in reused_components:
            cpn_pre_common_params = pre_common_params.get(cpn, {})
            cpn_cur_common_params = cur_common_params.get(cpn, {})
            if cpn_pre_common_params != cpn_cur_common_params:
                raise ValueError(f"{cpn}'s common parameters old:{cpn_pre_common_params} != new:{cpn_cur_common_params}")

        # step3: check component role parameters
        first_role_params = pre_role_params
        second_role_params = cur_role_params
        for idx in range(2):
            for r, role_params in first_role_params.items():
                for party_idx, params in role_params.items():
                    for cpn in reused_components:
                        cpn_first_role_params = params.get(cpn)
                        if not cpn_first_role_params:
                            continue

                        cpn_second_role_params = second_role_params.get(r, {}).get(party_idx, {}).get(cpn)
                        if cpn_first_role_params != cpn_second_role_params:
                            if idx == 1:
                                cpn_first_role_params, cpn_second_role_params = cpn_second_role_params, cpn_first_role_params
                            raise ValueError(f"{cpn}'s role parameters old:{r}-{party_idx}-{cpn_first_role_params} "
                                             f"!= new: {r}-{party_idx}-{cpn_second_role_params}")

            first_role_params, second_role_params = cur_role_params, pre_role_params
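

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The train_dsl/train_conf dictionaries and provider_detail below are assumed
# placeholders for a submitted job; the call sequence mirrors DSLParserV2.run().
#
#     dsl_parser = DSLParserV2()
#     dsl_parser.run(dsl=train_dsl,
#                    runtime_conf=train_conf,
#                    pipeline_runtime_conf=None,
#                    provider_detail=provider_detail,
#                    mode="train",
#                    local_role="guest",
#                    local_party_id=9999)
#     topo = [cpn.get_name() for cpn in dsl_parser.get_topology_components()]
#     dependency = dsl_parser.get_dependency()
# ----------------------------------------------------------------------------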