pipeline.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. import getpass
  18. import json
  19. import pickle
  20. import time
  21. from types import SimpleNamespace
  22. from pipeline.backend.config import Role
  23. from pipeline.backend.config import StatusCode
  24. from pipeline.backend.config import VERSION
  25. from pipeline.backend.config import PipelineConfig
  26. from pipeline.backend._operation import OnlineCommand, ModelConvert
  27. from pipeline.backend.task_info import TaskInfo
  28. from pipeline.component.component_base import Component
  29. from pipeline.component.reader import Reader
  30. from pipeline.interface import Data
  31. from pipeline.interface import Model
  32. from pipeline.interface import Cache
  33. from pipeline.utils import tools
  34. from pipeline.utils.invoker.job_submitter import JobInvoker
  35. from pipeline.utils.logger import LOGGER
  36. from pipeline.runtime.entity import JobParameters
class PipeLine(object):
    """Builder for a FATE job: collects components, compiles train/predict
    DSL and conf dicts, and submits/monitors jobs through a JobInvoker."""

    def __init__(self):
        # Human-readable creation timestamp, reported by describe().
        self._create_time = time.asctime(time.localtime(time.time()))
        self._initiator = None           # SimpleNamespace(role=..., party_id=...)
        self._roles = {}                 # role name -> list of party ids
        self._components = {}            # component name -> Component instance
        self._components_input = {}      # component name -> {"data"/"model"/"cache": ...}
        self._train_dsl = {}
        self._predict_dsl = {}
        self._train_conf = {}
        self._predict_conf = {}
        self._upload_conf = []           # queued by add_upload_data(), consumed by upload()
        self._cur_state = None           # set to "fit" by fit(); drives get_component()
        self._job_invoker = JobInvoker()
        self._train_job_id = None
        self._predict_job_id = None
        self._fit_status = None
        self._train_board_url = None
        self._model_info = None          # trained model id/version (SimpleNamespace)
        self._predict_model_info = None  # deployed model id/version (SimpleNamespace)
        self._train_components = {}
        self._stage = "fit"              # becomes "predict" when a pipeline is added as input
        self._data_to_feed_in_prediction = None
        self._predict_pipeline = []      # [{"pipeline": PipeLine, "data": ...}]
        self._deploy = False
        self._system_role = PipelineConfig.SYSTEM_SETTING.get("role")
        self.online = OnlineCommand(self)
        self._load = False
        self.model_convert = ModelConvert(self)
        self._global_job_provider = None
    @LOGGER.catch(reraise=True)
    def set_initiator(self, role, party_id):
        """Set the job initiator; mirrored onto an attached predict pipeline.

        Returns self for chaining.
        """
        self._initiator = SimpleNamespace(role=role, party_id=party_id)

        # for predict pipeline
        if self._predict_pipeline:
            predict_pipeline = self._predict_pipeline[0]["pipeline"]
            predict_pipeline._initiator = SimpleNamespace(role=role, party_id=party_id)

        return self
    def get_component_list(self):
        """Return a copied list of component names, in insertion order."""
        return copy.copy(list(self._components.keys()))

    def restore_roles(self, initiator, roles):
        # Adopt a fitted pipeline's initiator/roles when building a predict pipeline.
        self._initiator = initiator
        self._roles = roles
    @LOGGER.catch(reraise=True)
    def get_predict_meta(self):
        """Bundle everything a predict pipeline needs from this fitted pipeline.

        Raises
        ------
        ValueError
            If fit() has not completed successfully.
        """
        if self._fit_status != StatusCode.SUCCESS:
            raise ValueError("To get predict meta, please fit successfully")

        return {"predict_dsl": self._predict_dsl,
                "train_conf": self._train_conf,
                "initiator": self._initiator,
                "roles": self._roles,
                "model_info": self._model_info,
                "components": self._components,
                "stage": self._stage
                }
    def get_predict_model_info(self):
        """Deep copy of the deployed (predict) model id/version info."""
        return copy.deepcopy(self._predict_model_info)

    def get_model_info(self):
        """Deep copy of the trained model id/version info."""
        return copy.deepcopy(self._model_info)

    def get_train_dsl(self):
        return copy.deepcopy(self._train_dsl)

    def get_train_conf(self):
        return copy.deepcopy(self._train_conf)

    def get_predict_dsl(self):
        return copy.deepcopy(self._predict_dsl)

    def get_predict_conf(self):
        return copy.deepcopy(self._predict_conf)

    def get_upload_conf(self):
        return copy.deepcopy(self._upload_conf)
    def _get_initiator_conf(self):
        """Return the {"role", "party_id"} dict for the conf's "initiator" section."""
        if self._initiator is None:
            raise ValueError("Please set initiator of PipeLine")

        initiator_conf = {"role": self._initiator.role,
                          "party_id": self._initiator.party_id}

        return initiator_conf

    def set_global_job_provider(self, provider):
        # Stored verbatim; emitted as the job-level "provider" field of the train DSL.
        self._global_job_provider = provider
        return self
  115. @LOGGER.catch(reraise=True)
  116. def set_roles(self, guest=None, host=None, arbiter=None, **kwargs):
  117. local_parameters = locals()
  118. support_roles = Role.support_roles()
  119. for role, party_id in local_parameters.items():
  120. if role == "self":
  121. continue
  122. if not local_parameters.get(role):
  123. continue
  124. if role not in support_roles:
  125. raise ValueError("Current role not support {}, support role list {}".format(role, support_roles))
  126. party_id = local_parameters.get(role)
  127. self._roles[role] = []
  128. if isinstance(party_id, int):
  129. self._roles[role].append(party_id)
  130. elif isinstance(party_id, list):
  131. self._roles[role].extend(party_id)
  132. else:
  133. raise ValueError("role: {}'s party_id should be an integer or a list of integer".format(role))
  134. # update role config for compiled pipeline
  135. if self._train_conf:
  136. if role in self._train_conf["role"]:
  137. self._train_conf["role"][role] = self._roles[role]
  138. if self._predict_pipeline:
  139. predict_pipeline = self._predict_pipeline[0]["pipeline"]
  140. predict_pipeline._roles = self._roles
  141. return self
    def _get_role_conf(self):
        return self._roles

    def _get_party_index(self, role, party_id):
        """Return the index of party_id within the configured party list of *role*."""
        if role not in self._roles:
            raise ValueError("role {} does not setting".format(role))

        if party_id not in self._roles[role]:
            raise ValueError("role {} does not init setting with the party_id {}".format(role, party_id))

        return self._roles[role].index(party_id)
    @LOGGER.catch(reraise=True)
    def add_component(self, component, data=None, model=None, cache=None):
        """Register a component and record its data/model/cache wiring.

        Parameters
        ----------
        component : Component, or a deployed training PipeLine (which switches
            this pipeline into the "predict" stage).
        data : Data, optional — upstream data references.
        model : Model, optional — upstream model references (not allowed when
            component is a PipeLine).
        cache : Cache, optional — upstream cache references.

        Returns self for chaining.
        """
        if isinstance(component, PipeLine):
            if component.is_deploy() is False:
                raise ValueError("To use a training pipeline object as predict component, should deploy model first")

            if model:
                raise ValueError("pipeline should not have model as input!")

            if not data:
                raise ValueError("To use pipeline as a component, please set data input")

            self._stage = "predict"
            self._predict_pipeline.append({"pipeline": component, "data": data.predict_input})

            meta = component.get_predict_meta()
            self.restore_roles(meta.get("initiator"), meta.get("roles"))

            return self

        if not isinstance(component, Component):
            raise ValueError(
                "To add a component to pipeline, component {} should be a Component object".format(component))

        if component.name in self._components:
            # NOTE(review): raising Warning as an exception aborts here;
            # presumably warnings.warn was intended — confirm before changing.
            raise Warning("component {} is added before".format(component.name))

        self._components[component.name] = component

        if data is not None:
            if not isinstance(data, Data):
                raise ValueError("data input of component {} should be passed by data object".format(component.name))

            attrs_dict = vars(data)
            self._components_input[component.name] = {"data": {}}
            # collect every non-None "*data" attribute of the Data object
            for attr, val in attrs_dict.items():
                if not attr.endswith("data"):
                    continue

                if val is None:
                    continue

                # "_train_data" -> "train_data" etc.
                data_key = attr.strip("_")

                if isinstance(val, list):
                    self._components_input[component.name]["data"][data_key] = val
                else:
                    self._components_input[component.name]["data"][data_key] = [val]

        if model is not None:
            if not isinstance(model, Model):
                raise ValueError("model input of component {} should be passed by model object".format(component.name))

            attrs_dict = vars(model)
            for attr, val in attrs_dict.items():
                if not attr.endswith("model"):
                    continue

                if val is None:
                    continue

                if isinstance(val, list):
                    self._components_input[component.name][attr.strip("_")] = val
                else:
                    self._components_input[component.name][attr.strip("_")] = [val]

        if cache is not None:
            if not isinstance(cache, Cache):
                raise ValueError("cache input of component {} should be passed by cache object".format(component.name))

            attr = cache.cache
            if not isinstance(attr, list):
                attr = [attr]

            self._components_input[component.name]["cache"] = attr

        return self
    @LOGGER.catch(reraise=True)
    def add_upload_data(self, file, table_name, namespace, head=1, partition=16,
                        id_delimiter=",", extend_sid=False, auto_increasing_sid=False, **kargs):
        """Queue a local file for upload as table (namespace, table_name).

        Nothing is submitted here; upload() consumes the queued entries.
        Extra keyword arguments are passed through into the upload conf.
        """
        data_conf = {"file": file,
                     "table_name": table_name,
                     "namespace": namespace,
                     "head": head,
                     "partition": partition,
                     "id_delimiter": id_delimiter,
                     "extend_sid": extend_sid,
                     "auto_increasing_sid": auto_increasing_sid, **kargs}
        self._upload_conf.append(data_conf)
    def _get_task_inst(self, job_id, name, init_role, party_id):
        """Build a TaskInfo handle for component *name* under *job_id*.

        For predict pipelines the component may live in the underlying
        training pipeline's component set; raises ValueError when not found.
        """
        component = None
        if name in self._components:
            component = self._components[name]

        if component is None:
            if self._stage != "predict":
                raise ValueError(f"Component {name} does not exist")

            # fall back to the deployed training pipeline's components
            training_meta = self._predict_pipeline[0]["pipeline"].get_predict_meta()

            component = training_meta.get("components").get(name)
            if component is None:
                raise ValueError(f"Component {name} does not exist")

        return TaskInfo(jobid=job_id,
                        component=component,
                        job_client=self._job_invoker,
                        role=init_role,
                        party_id=party_id)
  234. @LOGGER.catch(reraise=True)
  235. def get_component(self, component_names=None):
  236. job_id = self._train_job_id
  237. if self._cur_state != "fit":
  238. job_id = self._predict_job_id
  239. init_role = self._initiator.role
  240. party_id = self._initiator.party_id
  241. if not component_names:
  242. component_tasks = {}
  243. for name in self._components:
  244. component_tasks[name] = self._get_task_inst(job_id, name, init_role, party_id)
  245. return component_tasks
  246. elif isinstance(component_names, str):
  247. return self._get_task_inst(job_id, component_names, init_role, party_id)
  248. elif isinstance(component_names, list):
  249. component_tasks = []
  250. for name in component_names:
  251. component_tasks.append(self._get_task_inst(job_id, name, init_role, party_id))
  252. return component_tasks
    def _construct_train_dsl(self):
        """Build self._train_dsl: per-component module/input/output plus an
        optional "provider" ("name" or "name@version") per component."""
        if self._global_job_provider:
            self._train_dsl["provider"] = self._global_job_provider

        self._train_dsl["components"] = {}
        for name, component in self._components.items():
            component_dsl = {"module": component.module}
            if name in self._components_input:
                component_dsl["input"] = self._components_input[name]

            if hasattr(component, "output"):
                component_dsl["output"] = {}
                output_attrs = {"data": "data_output",
                                "model": "model_output",
                                "cache": "cache_output"}
                for output_key, attr in output_attrs.items():
                    if hasattr(component.output, attr):
                        component_dsl["output"][output_key] = getattr(component.output, attr)

            provider_name = None
            provider_version = None
            if not hasattr(component, "source_provider"):
                LOGGER.warning(f"Can not retrieval source provider of component {name}, "
                               f"refer to pipeline/component/component_base.py")
            else:
                provider_name = getattr(component, "source_provider")
                if provider_name is None:
                    LOGGER.warning(f"Source provider of component {name} is None, "
                                   f"refer to pipeline/component/component_base.py")

            if hasattr(component, "provider"):
                provider = getattr(component, "provider")
                if provider is not None:
                    # "name@version" overrides both fields; a bare name
                    # overrides the name only
                    if provider.find("@") != -1:
                        provider_name, provider_version = provider.split("@", -1)
                    else:
                        provider_name = provider
                    # component_dsl["provider"] = provider

            # NOTE(review): unguarded getattr — assumes every component defines
            # provider_version (raises AttributeError otherwise); confirm
            # against component_base.py.
            if getattr(component, "provider_version") is not None:
                provider_version = getattr(component, "provider_version")

            if provider_name and provider_version:
                component_dsl["provider"] = "@".join([provider_name, provider_version])
            elif provider_name:
                component_dsl["provider"] = provider_name

            self._train_dsl["components"][name] = component_dsl

        if not self._train_dsl:
            raise ValueError("there are no components to train")

        LOGGER.debug(f"train_dsl: {self._train_dsl}")
    def _construct_train_conf(self):
        """Build and return self._train_conf: dsl_version, initiator, roles,
        job_parameters and per-component parameters (common + per-role)."""
        self._train_conf["dsl_version"] = VERSION
        self._train_conf["initiator"] = self._get_initiator_conf()
        self._train_conf["role"] = self._roles
        self._train_conf["job_parameters"] = {"common": {"job_type": "train"}}
        for name, component in self._components.items():
            param_conf = component.get_config(version=VERSION, roles=self._roles)

            if "common" in param_conf:
                common_param_conf = param_conf["common"]
                if "component_parameters" not in self._train_conf:
                    self._train_conf["component_parameters"] = {}
                if "common" not in self._train_conf["component_parameters"]:
                    self._train_conf["component_parameters"]["common"] = {}

                self._train_conf["component_parameters"]["common"].update(common_param_conf)

            if "role" in param_conf:
                role_param_conf = param_conf["role"]
                if "component_parameters" not in self._train_conf:
                    self._train_conf["component_parameters"] = {}
                if "role" not in self._train_conf["component_parameters"]:
                    self._train_conf["component_parameters"]["role"] = {}

                # accumulate role-scoped params across components via merge_dict
                self._train_conf["component_parameters"]["role"] = tools.merge_dict(
                    role_param_conf, self._train_conf["component_parameters"]["role"])

        LOGGER.debug(f"self._train_conf: \n {json.dumps(self._train_conf, indent=4, ensure_ascii=False)}")

        return self._train_conf
  321. def _construct_upload_conf(self, data_conf):
  322. upload_conf = copy.deepcopy(data_conf)
  323. # upload_conf["work_mode"] = work_mode
  324. return upload_conf
    def describe(self):
        """Log this pipeline's stage, its active DSL and its creation time."""
        LOGGER.info(f"Pipeline Stage is {self._stage}")
        LOGGER.info("DSL is:")
        if self._stage == "fit":
            LOGGER.info(f"{self._train_dsl}")
        else:
            LOGGER.info(f"{self._predict_dsl}")
        LOGGER.info(f"Pipeline Create Time: {self._create_time}")
    def get_train_job_id(self):
        return self._train_job_id

    def get_predict_job_id(self):
        return self._predict_job_id

    def _set_state(self, state):
        # "fit" or "predict"; get_component() uses it to pick the job id.
        self._cur_state = state

    def set_job_invoker(self, job_invoker):
        # Re-injected after unpickling (see load / load_model_from_file).
        self._job_invoker = job_invoker
    @LOGGER.catch(reraise=True)
    def compile(self):
        """Materialize the train DSL/conf; for a predict-stage pipeline also
        merge the deployed training pipeline's DSL/conf and wire its data
        placeholders. Returns self for chaining."""
        self._construct_train_dsl()
        self._train_conf = self._construct_train_conf()
        if self._stage == "predict":
            predict_pipeline = self._predict_pipeline[0]["pipeline"]
            data_info = self._predict_pipeline[0]["data"]

            meta = predict_pipeline.get_predict_meta()
            if meta["stage"] == "predict":
                raise ValueError(
                    "adding predict pipeline objects'stage is predict, a predict pipeline cannot be an input component")

            self._model_info = meta["model_info"]
            predict_pipeline_dsl = meta["predict_dsl"]
            predict_pipeline_conf = meta["train_conf"]
            if not predict_pipeline_dsl:
                raise ValueError(
                    "Cannot find deploy model in predict pipeline, to use a pipeline as input component, "
                    "it should be deploy first")

            # component names must be disjoint between the two pipelines
            for cpn in self._train_dsl["components"]:
                if cpn in predict_pipeline_dsl["components"]:
                    raise ValueError(
                        f"component name {cpn} exist in predict pipeline's deploy component, this is not support")

            # training conf values only fill gaps; ours take precedence
            if "algorithm_parameters" in predict_pipeline_conf:
                algo_param = predict_pipeline_conf["algorithm_parameters"]
                if "algorithm_parameters" in self._train_conf:
                    for key, value in algo_param.items():
                        if key not in self._train_conf["algorithm_parameters"]:
                            self._train_conf["algorithm_parameters"][key] = value
                else:
                    self._train_conf["algorithm_parameters"] = algo_param

            if "role_parameters" in predict_pipeline_conf:
                role_param = predict_pipeline_conf["role_parameters"]
                # drop training-side entries for components this pipeline redefines
                for cpn in self._train_dsl["components"]:
                    for role, param in role_param.items():
                        for idx in param:
                            if param[idx].get(cpn) is not None:
                                del predict_pipeline_conf["role_parameters"][role][idx][cpn]

                if "role_parameters" not in self._train_conf:
                    self._train_conf["role_parameters"] = {}

                self._train_conf["role_parameters"] = tools.merge_dict(self._train_conf["role_parameters"],
                                                                       predict_pipeline_conf["role_parameters"])

            self._predict_dsl = tools.merge_dict(predict_pipeline_dsl, self._train_dsl)

            # data_info keys look like "<component>.<dataset>"
            for data_field, val in data_info.items():
                cpn = data_field.split(".", -1)[0]
                dataset = data_field.split(".", -1)[1]
                if not isinstance(val, list):
                    val = [val]

                if "input" not in self._predict_dsl["components"][cpn]:
                    self._predict_dsl["components"][cpn]["input"] = {}

                if 'data' not in self._predict_dsl["components"][cpn]["input"]:
                    self._predict_dsl["components"][cpn]["input"]["data"] = {}

                self._predict_dsl["components"][cpn]["input"]["data"][dataset] = val

        return self
    @LOGGER.catch(reraise=True)
    def _check_duplicate_setting(self, submit_conf):
        """Raise if any party of the configured system role already has a
        "user" entry — _feed_job_parameters would otherwise overwrite it."""
        system_role = self._system_role
        if "role" in submit_conf["job_parameters"]:
            role_conf = submit_conf["job_parameters"]["role"]
            system_role_conf = role_conf.get(system_role, {})
            for party, conf in system_role_conf.items():
                if conf.get("user"):
                    raise ValueError(f"system role {system_role}'s user info already set. Please check.")
    def _feed_job_parameters(self, conf, job_type=None,
                             model_info=None, job_parameters=None):
        """Return a deep copy of *conf* with job parameters stamped in.

        Sets job_type (and model id/version when model_info is given) under
        job_parameters.common; when a system role is configured, also records
        the submitting OS user for the initiator party.
        """
        submit_conf = copy.deepcopy(conf)
        LOGGER.debug(f"submit conf type is {type(submit_conf)}")

        if job_parameters:
            submit_conf["job_parameters"] = job_parameters.get_config(roles=self._roles)

        if "common" not in submit_conf["job_parameters"]:
            submit_conf["job_parameters"]["common"] = {}

        submit_conf["job_parameters"]["common"]["job_type"] = job_type

        if model_info is not None:
            submit_conf["job_parameters"]["common"]["model_id"] = model_info.model_id
            submit_conf["job_parameters"]["common"]["model_version"] = model_info.model_version

        if self._system_role:
            # refuse to clobber a pre-existing user entry
            self._check_duplicate_setting(submit_conf)
            init_role = self._initiator.role
            # per-role index of the initiator party, as a string key
            idx = str(self._roles[init_role].index(self._initiator.party_id))
            if "role" not in submit_conf["job_parameters"]:
                submit_conf["job_parameters"]["role"] = {}

            if init_role not in submit_conf["job_parameters"]["role"]:
                submit_conf["job_parameters"]["role"][init_role] = {}

            if idx not in submit_conf["job_parameters"]["role"][init_role]:
                submit_conf["job_parameters"]["role"][init_role][idx] = {}

            submit_conf["job_parameters"]["role"][init_role][idx].update({"user": getpass.getuser()})

        return submit_conf
  427. def _filter_out_deploy_component(self, predict_conf):
  428. if "component_parameters" not in predict_conf:
  429. return predict_conf
  430. if "common" in predict_conf["component_parameters"]:
  431. cpns = list(predict_conf["component_parameters"]["common"])
  432. for cpn in cpns:
  433. if cpn not in self._components.keys():
  434. del predict_conf["component_parameters"]["common"]
  435. if "role" in predict_conf["component_parameters"]:
  436. roles = predict_conf["component_parameters"]["role"].keys()
  437. for role in roles:
  438. role_params = predict_conf["component_parameters"]["role"].get(role)
  439. indexs = role_params.keys()
  440. for idx in indexs:
  441. cpns = role_params[idx].keys()
  442. for cpn in cpns:
  443. if cpn not in self._components.keys():
  444. del role_params[idx][cpn]
  445. if not role_params[idx]:
  446. del role_params[idx]
  447. if role_params:
  448. predict_conf["component_parameters"]["role"][role] = role_params
  449. else:
  450. del predict_conf["component_parameters"]["role"][role]
  451. return predict_conf
    @LOGGER.catch(reraise=True)
    def fit(self, job_parameters=None, callback_func=None):
        """Submit the training job and block until it finishes.

        Parameters
        ----------
        job_parameters : JobParameters, optional
        callback_func : callable, optional — forwarded to JobInvoker.submit_job.

        Raises
        ------
        ValueError
            On a predict-stage pipeline, or when job_parameters is not a
            JobParameters instance.
        """
        if self._stage == "predict":
            raise ValueError("This pipeline is constructed for predicting, cannot use fit interface")

        if job_parameters and not isinstance(job_parameters, JobParameters):
            raise ValueError("input parameter of fit function should be JobParameters object")

        LOGGER.debug(f"in fit, _train_conf is: \n {json.dumps(self._train_conf)}")
        self._set_state("fit")
        training_conf = self._feed_job_parameters(self._train_conf, job_type="train", job_parameters=job_parameters)
        self._train_conf = training_conf
        LOGGER.debug(f"train_conf is: \n {json.dumps(training_conf, indent=4, ensure_ascii=False)}")

        self._train_job_id, detail_info = self._job_invoker.submit_job(self._train_dsl, training_conf, callback_func)
        self._train_board_url = detail_info["board_url"]
        self._model_info = SimpleNamespace(model_id=detail_info["model_info"]["model_id"],
                                           model_version=detail_info["model_info"]["model_version"])

        # blocks until the job reaches a terminal status
        self._fit_status = self._job_invoker.monitor_job_status(self._train_job_id,
                                                                self._initiator.role,
                                                                self._initiator.party_id)
    @LOGGER.catch(reraise=True)
    def update_model_info(self, model_id=None, model_version=None):
        """Override model id and/or version.

        Applies to the attached predict pipeline when one exists, otherwise to
        this training pipeline. Returns self for chaining.
        """
        # predict pipeline
        if self._predict_pipeline:
            predict_pipeline = self._predict_pipeline[0]["pipeline"]
            if model_id:
                predict_pipeline._model_info.model_id = model_id
            if model_version:
                predict_pipeline._model_info.model_version = model_version
            return self

        # train pipeline: merge overrides with any existing info
        original_model_id, original_model_version = None, None
        if self._model_info is not None:
            original_model_id, original_model_version = self._model_info.model_id, self._model_info.model_version
        new_model_id = model_id if model_id is not None else original_model_id
        new_model_version = model_version if model_version is not None else original_model_version

        # nothing known and nothing supplied: leave _model_info untouched
        if new_model_id is None and new_model_version is None:
            return self
        self._model_info = SimpleNamespace(model_id=new_model_id, model_version=new_model_version)
        return self
    @LOGGER.catch(reraise=True)
    def continuously_fit(self):
        """Resume monitoring a previously submitted training job, carrying
        over the last observed status."""
        self._fit_status = self._job_invoker.monitor_job_status(self._train_job_id,
                                                                self._initiator.role,
                                                                self._initiator.party_id,
                                                                previous_status=self._fit_status)
    @LOGGER.catch(reraise=True)
    def predict(self, job_parameters=None, components_checkpoint=None):
        """Compile, deploy the model server-side, submit the predict job and
        block until it finishes.

        Parameters
        ----------
        job_parameters : JobParameters, optional
        components_checkpoint : dict, optional
            Which model checkpoint to take, e.g. {"hetero_lr_0": {"step_index": 8}}.

        Raises
        ------
        ValueError
            If the pipeline is not in the "predict" stage, or job_parameters
            is not a JobParameters instance.
        """
        if self._stage != "predict":
            raise ValueError(
                "To use predict function, please deploy component(s) from training pipeline"
                "and construct a new predict pipeline with data reader and training pipeline.")

        if job_parameters and not isinstance(job_parameters, JobParameters):
            raise ValueError("input parameter of fit function should be JobParameters object")

        self.compile()

        res_dict = self._job_invoker.model_deploy(model_id=self._model_info.model_id,
                                                  model_version=self._model_info.model_version,
                                                  predict_dsl=self._predict_dsl,
                                                  components_checkpoint=components_checkpoint)
        self._predict_model_info = SimpleNamespace(model_id=res_dict["model_id"],
                                                   model_version=res_dict["model_version"])

        predict_conf = self._feed_job_parameters(self._train_conf,
                                                 job_type="predict",
                                                 model_info=self._predict_model_info,
                                                 job_parameters=job_parameters)

        # drop parameters of components that were not deployed
        predict_conf = self._filter_out_deploy_component(predict_conf)
        self._predict_conf = copy.deepcopy(predict_conf)
        predict_dsl = copy.deepcopy(self._predict_dsl)

        self._predict_job_id, _ = self._job_invoker.submit_job(dsl=predict_dsl, submit_conf=predict_conf)
        self._job_invoker.monitor_job_status(self._predict_job_id,
                                             self._initiator.role,
                                             self._initiator.party_id)
    @LOGGER.catch(reraise=True)
    def upload(self, drop=0):
        """Submit every entry queued by add_upload_data() as a local upload
        job, waiting for each to finish in turn.

        # drop is forwarded to JobInvoker.upload_data — presumably non-zero
        # replaces an existing table; confirm against JobInvoker.
        """
        for data_conf in self._upload_conf:
            upload_conf = self._construct_upload_conf(data_conf)
            LOGGER.debug(f"upload_conf is {json.dumps(upload_conf)}")
            self._train_job_id, detail_info = self._job_invoker.upload_data(upload_conf, int(drop))
            self._train_board_url = detail_info["board_url"]
            # upload jobs always run as role "local", party 0
            self._job_invoker.monitor_job_status(self._train_job_id,
                                                 "local",
                                                 0)
    @LOGGER.catch(reraise=True)
    def dump(self, file_path=None):
        """Pickle this pipeline, optionally writing the bytes to *file_path*;
        always returns the pickled bytes."""
        pkl = pickle.dumps(self)

        if file_path is not None:
            with open(file_path, "wb") as fout:
                fout.write(pkl)

        return pkl
    @classmethod
    def load(cls, pipeline_bytes):
        """Rebuild a PipeLine from bytes produced by dump(), re-injecting a
        fresh JobInvoker.

        Only unpickle trusted pipeline dumps: pickle.loads executes arbitrary
        code from malicious input.
        """
        pipeline_obj = pickle.loads(pipeline_bytes)
        pipeline_obj.set_job_invoker(JobInvoker())
        return pipeline_obj
    @classmethod
    def load_model_from_file(cls, file_path):
        """Rebuild a PipeLine from a pickle file written by dump(); only use
        with trusted files (pickle executes arbitrary code)."""
        with open(file_path, "rb") as fin:
            pipeline_obj = pickle.loads(fin.read())
            pipeline_obj.set_job_invoker(JobInvoker())
            return pipeline_obj
    @LOGGER.catch(reraise=True)
    def deploy_component(self, components=None):
        """Deploy trained component(s) so this pipeline can be used to predict.

        Parameters
        ----------
        components : iterable of str or Component, optional
            Components to deploy; all registered components when None.
            Reader components are rejected.

        Returns self; sets _deploy once a predict DSL has been obtained.
        """
        if self._train_dsl is None:
            raise ValueError("Before deploy model, training should be finished!!!")

        if components is None:
            components = self._components

        deploy_cpns = []
        for cpn in components:
            if isinstance(cpn, str):
                deploy_cpns.append(cpn)
            elif isinstance(cpn, Component):
                deploy_cpns.append(cpn.name)
            else:
                raise ValueError(
                    "deploy component parameters is wrong, expect str or Component object, but {} find".format(
                        type(cpn)))

            # validate the name just appended
            if deploy_cpns[-1] not in self._components:
                raise ValueError("Component {} does not exist in pipeline".format(deploy_cpns[-1]))

            if isinstance(self._components.get(deploy_cpns[-1]), Reader):
                raise ValueError("Reader should not be include in predict pipeline")

        res_dict = self._job_invoker.model_deploy(model_id=self._model_info.model_id,
                                                  model_version=self._model_info.model_version,
                                                  cpn_list=deploy_cpns)
        self._predict_model_info = SimpleNamespace(model_id=res_dict["model_id"],
                                                   model_version=res_dict["model_version"])
        self._predict_dsl = self._job_invoker.get_predict_dsl(model_id=res_dict["model_id"],
                                                              model_version=res_dict["model_version"])

        if self._predict_dsl:
            self._deploy = True

        return self
    def is_deploy(self):
        # True once deploy_component() has retrieved a predict DSL.
        return self._deploy

    def is_load(self):
        return self._load
    @LOGGER.catch(reraise=True)
    def init_predict_config(self, config):
        """Turn this pipeline into a predict pipeline from a meta dict, or
        from a fitted PipeLine (whose get_predict_meta() is used)."""
        if isinstance(config, PipeLine):
            config = config.get_predict_meta()

        self._stage = "predict"
        self._model_info = config["model_info"]
        self._predict_dsl = config["predict_dsl"]
        self._train_conf = config["train_conf"]
        self._initiator = config["initiator"]
        # NOTE(review): get_predict_meta() exposes "components", not
        # "train_components" — this KeyErrors when fed such a meta dict;
        # confirm the expected config schema.
        self._train_components = config["train_components"]
    @LOGGER.catch(reraise=True)
    def get_component_input_msg(self):
        """Map each predict-DSL component to the Reader components feeding it.

        Returns {cpn_name: {data_type: [reader_cpn_name, ...]}}; only DSL
        version 2 is supported.
        """
        if VERSION != 2:
            raise ValueError("In DSL Version 1,only need to config data from args, do not need special component")

        need_input = {}
        for cpn_name, config in self._predict_dsl["components"].items():
            if "input" not in config:
                continue

            if "data" not in config["input"]:
                continue

            data_config = config["input"]["data"]
            for data_type, dataset_list in data_config.items():
                for data_set in dataset_list:
                    # dataset references look like "<component>.<dataset>"
                    input_cpn = data_set.split(".", -1)[0]
                    input_inst = self._components[input_cpn]
                    if isinstance(input_inst, Reader):
                        if cpn_name not in need_input:
                            need_input[cpn_name] = {}
                        need_input[cpn_name][data_type] = []
                        need_input[cpn_name][data_type].append(input_cpn)

        return need_input
  626. @LOGGER.catch(reraise=True)
  627. def get_input_reader_placeholder(self):
  628. input_info = self.get_component_input_msg()
  629. input_placeholder = set()
  630. for cpn_name, data_dict in input_info.items():
  631. for data_type, dataset_list in data_dict.items():
  632. for dataset in dataset_list:
  633. input_placeholder.add(dataset)
  634. return input_placeholder
    @LOGGER.catch(reraise=True)
    def set_inputs(self, data_dict):
        """Provide reader data for every input placeholder before predicting.

        data_dict maps placeholder name -> reader object; every placeholder
        from get_input_reader_placeholder() must be present.
        """
        if not isinstance(data_dict, dict):
            raise ValueError(
                "inputs for predicting should be a dict, key is input_placeholder name, value is a reader object")

        unfilled_placeholder = self.get_input_reader_placeholder() - set(data_dict.keys())
        if unfilled_placeholder:
            raise ValueError("input placeholder {} should be fill".format(unfilled_placeholder))

        self._data_to_feed_in_prediction = data_dict
    @LOGGER.catch(reraise=True)
    def bind_table(self, name, namespace, path, engine='PATH', replace=True, **kwargs):
        """Bind an external storage *path* to table (namespace, name) through
        JobInvoker and return its response."""
        info = self._job_invoker.bind_table(engine=engine, name=name, namespace=namespace, address={
            "path": path
        }, drop=replace, **kwargs)
        return info
  650. # @LOGGER.catch(reraise=True)
  651. def __getattr__(self, attr):
  652. if attr in self._components:
  653. return self._components[attr]
  654. return self.__getattribute__(attr)
  655. @LOGGER.catch(reraise=True)
  656. def __getitem__(self, item):
  657. if item not in self._components:
  658. raise ValueError("Pipeline does not has component }{}".format(item))
  659. return self._components[item]
    def __getstate__(self):
        # Pickle the full instance dict; callers re-inject a fresh JobInvoker
        # after load (see load / load_model_from_file).
        return vars(self)

    def __setstate__(self, state):
        vars(self).update(state)