#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
import shutil

from fate_arch.common import EngineType, engine_utils
from fate_arch.common.base_utils import current_timestamp, json_dumps
from fate_arch.computing import ComputingEngine

from fate_flow.controller.task_controller import TaskController
from fate_flow.db.db_models import PipelineComponentMeta
from fate_flow.db.job_default_config import JobDefaultConfig
from fate_flow.db.runtime_config import RuntimeConfig
from fate_flow.entity import RunParameters
from fate_flow.entity.run_status import EndStatus, JobInheritanceStatus, JobStatus, TaskStatus
from fate_flow.entity.types import RetCode, WorkerName
from fate_flow.manager.provider_manager import ProviderManager
from fate_flow.manager.resource_manager import ResourceManager
from fate_flow.manager.worker_manager import WorkerManager
from fate_flow.model.checkpoint import CheckpointManager
from fate_flow.model.sync_model import SyncComponent
from fate_flow.operation.job_saver import JobSaver
from fate_flow.operation.job_tracker import Tracker
from fate_flow.pipelined_model.pipelined_model import PipelinedComponent
from fate_flow.protobuf.python import pipeline_pb2
from fate_flow.scheduler.federated_scheduler import FederatedScheduler
from fate_flow.settings import ENABLE_MODEL_STORE, ENGINES
from fate_flow.utils import data_utils, job_utils, log_utils, schedule_utils
from fate_flow.utils.job_utils import get_job_dataset
from fate_flow.utils.log_utils import schedule_logger
from fate_flow.utils.model_utils import gather_model_info_data, save_model_info


class JobController(object):
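    """
    Party-side controller for the job lifecycle: job creation and
    parameter resolution, task initialization, status updates, job stop,
    pipelined model saving, and job inheritance reload.
    """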
    @classmethod
    def create_job(cls, job_id, role, party_id, job_info):
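        """
        Create the job on this party: parse the DSL and runtime conf,
        resolve this party's job parameters, save the job to the database,
        initialize its tasks and job tracker, and persist the job
        configuration files. Returns the task initialization result.
        """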
        # parse job configuration
        dsl = job_info['dsl']
        runtime_conf = job_info['runtime_conf']
        train_runtime_conf = job_info['train_runtime_conf']
        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=dsl,
            runtime_conf=runtime_conf,
            train_runtime_conf=train_runtime_conf
        )
        job_parameters = dsl_parser.get_job_parameters(runtime_conf)
        schedule_logger(job_id).info('job parameters:{}'.format(job_parameters))
        dest_user = job_parameters.get(role, {}).get(party_id, {}).get('user', '')
        user = {}
        src_party_id = int(job_info['src_party_id']) if job_info.get('src_party_id') else 0
        src_role = job_info.get('src_role', '')
        src_user = job_parameters.get(src_role, {}).get(src_party_id, {}).get('user', '') if src_role else ''
        for _role, party_id_item in job_parameters.items():
            user[_role] = {}
            for _party_id, _parameters in party_id_item.items():
                user[_role][_party_id] = _parameters.get("user", "")
        job_parameters = RunParameters(**job_parameters.get(role, {}).get(party_id, {}))

        # save new job into db
        if role == job_info["initiator_role"] and party_id == job_info["initiator_party_id"]:
            is_initiator = True
        else:
            is_initiator = False
        job_info["status"] = JobStatus.READY
        job_info["user_id"] = dest_user
        job_info["src_user"] = src_user
        job_info["user"] = user
        # this party configuration
        job_info["role"] = role
        job_info["party_id"] = party_id
        job_info["is_initiator"] = is_initiator
        job_info["progress"] = 0
        cls.create_job_parameters_on_party(role=role, party_id=party_id, job_parameters=job_parameters)
        # update job parameters on party
        job_info["runtime_conf_on_party"]["job_parameters"] = job_parameters.to_dict()
        JobSaver.create_job(job_info=job_info)

        schedule_logger(job_id).info("start initialize tasks")
        initialized_result, provider_group = cls.initialize_tasks(job_id=job_id,
                                                                  role=role,
                                                                  party_id=party_id,
                                                                  run_on_this_party=True,
                                                                  initiator_role=job_info["initiator_role"],
                                                                  initiator_party_id=job_info["initiator_party_id"],
                                                                  job_parameters=job_parameters,
                                                                  dsl_parser=dsl_parser,
                                                                  runtime_conf=runtime_conf,
                                                                  check_version=True)
        schedule_logger(job_id).info("initialize tasks success")
        for provider_key, group_info in provider_group.items():
            for cpn in group_info["components"]:
                dsl["components"][cpn]["provider"] = provider_key

        roles = job_info['roles']
        cls.initialize_job_tracker(job_id=job_id, role=role, party_id=party_id,
                                   job_parameters=job_parameters, roles=roles,
                                   is_initiator=is_initiator, dsl_parser=dsl_parser)

        job_utils.save_job_conf(job_id=job_id,
                                role=role,
                                party_id=party_id,
                                dsl=dsl,
                                runtime_conf=runtime_conf,
                                runtime_conf_on_party=job_info["runtime_conf_on_party"],
                                train_runtime_conf=train_runtime_conf,
                                pipeline_dsl=None)
        return {"components": initialized_result}

    @classmethod
    def set_federated_mode(cls, job_parameters: RunParameters):
        if not job_parameters.federated_mode:
            job_parameters.federated_mode = ENGINES["federated_mode"]

    @classmethod
    def set_engines(cls, job_parameters: RunParameters, engine_type=None):
        engines = engine_utils.get_engines()
        if not engine_type:
            engine_type = {EngineType.COMPUTING, EngineType.FEDERATION, EngineType.STORAGE}
        for k in engine_type:
            setattr(job_parameters, f"{k}_engine", engines[k])

    @classmethod
    def create_common_job_parameters(cls, job_id, initiator_role, common_job_parameters: RunParameters):
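        """
        Fill the initiator-side (common) job parameters: federated mode,
        computing engine, defaults from JobDefaultConfig, and the baseline
        engine adaptation later used by each party.
        """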
        JobController.set_federated_mode(job_parameters=common_job_parameters)
        JobController.set_engines(job_parameters=common_job_parameters, engine_type={EngineType.COMPUTING})
        JobController.fill_default_job_parameters(job_id=job_id, job_parameters=common_job_parameters)
        JobController.adapt_job_parameters(role=initiator_role, job_parameters=common_job_parameters, create_initiator_baseline=True)

    @classmethod
    def create_job_parameters_on_party(cls, role, party_id, job_parameters: RunParameters):
        JobController.set_engines(job_parameters=job_parameters)
        cls.fill_party_specific_parameters(role=role,
                                           party_id=party_id,
                                           job_parameters=job_parameters)

    @classmethod
    def fill_party_specific_parameters(cls, role, party_id, job_parameters: RunParameters):
        cls.adapt_job_parameters(role=role, job_parameters=job_parameters)
        engines_info = cls.get_job_engines_address(job_parameters=job_parameters)
        cls.check_parameters(job_parameters=job_parameters,
                             role=role, party_id=party_id, engines_info=engines_info)

    @classmethod
    def fill_default_job_parameters(cls, job_id, job_parameters: RunParameters):
        keys = {"task_parallelism", "auto_retries", "auto_retry_delay", "federated_status_collect_type"}
        for key in keys:
            if hasattr(job_parameters, key) and getattr(job_parameters, key) is None:
                if hasattr(JobDefaultConfig, key):
                    setattr(job_parameters, key, getattr(JobDefaultConfig, key))
                else:
                    schedule_logger(job_id).warning(f"cannot find default value of job parameter {key} in job_default_settings")

    @classmethod
    def adapt_job_parameters(cls, role, job_parameters: RunParameters, create_initiator_baseline=False):
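        """
        Adapt the engine-related parameters for this role; when building
        the initiator baseline, also fill task parallelism, the federated
        status collect type and a default number of computing partitions.
        """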
        ResourceManager.adapt_engine_parameters(
            role=role, job_parameters=job_parameters, create_initiator_baseline=create_initiator_baseline)
        if create_initiator_baseline:
            if job_parameters.task_parallelism is None:
                job_parameters.task_parallelism = JobDefaultConfig.task_parallelism
            if job_parameters.federated_status_collect_type is None:
                job_parameters.federated_status_collect_type = JobDefaultConfig.federated_status_collect_type
        if create_initiator_baseline and not job_parameters.computing_partitions:
            job_parameters.computing_partitions = job_parameters.adaptation_parameters[
                "task_cores_per_node"] * job_parameters.adaptation_parameters["task_nodes"]

    @classmethod
    def get_job_engines_address(cls, job_parameters: RunParameters):
        engines_info = {}
        engine_list = [
            (EngineType.COMPUTING, job_parameters.computing_engine),
            (EngineType.FEDERATION, job_parameters.federation_engine),
            (EngineType.STORAGE, job_parameters.storage_engine)
        ]
        for engine_type, engine_name in engine_list:
            engine_info = ResourceManager.get_engine_registration_info(
                engine_type=engine_type, engine_name=engine_name)
            job_parameters.engines_address[engine_type] = engine_info.f_engine_config if engine_info else {}
            engines_info[engine_type] = engine_info
        return engines_info

    @classmethod
    def check_parameters(cls, job_parameters: RunParameters, role, party_id, engines_info):
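        """
        Verify that the requested resources can be applied for on this
        party; raise a RuntimeError with a hint on how to lower the
        request if the job exceeds the per-job core limit.
        """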
        status, cores_submit, max_cores_per_job = ResourceManager.check_resource_apply(
            job_parameters=job_parameters, role=role, party_id=party_id, engines_info=engines_info)
        if not status:
            msg = ""
            msg2 = "the default value is fate_flow/settings.py#DEFAULT_TASK_CORES_PER_NODE, refer to fate_flow/examples/simple/simple_job_conf.json"
            if job_parameters.computing_engine in {ComputingEngine.EGGROLL, ComputingEngine.STANDALONE}:
                msg = "please use the task_cores job parameter to set the requested task cores, or customize it with the eggroll_run job parameter"
            elif job_parameters.computing_engine in {ComputingEngine.SPARK}:
                msg = "please use the task_cores job parameter to set the requested task cores, or customize it with the spark_run job parameter"
            raise RuntimeError(
                f"max cores per job is {max_cores_per_job}, based on (fate_flow/settings#MAX_CORES_PERCENT_PER_JOB * conf/service_conf.yaml#nodes * conf/service_conf.yaml#cores_per_node), expect {cores_submit} cores, {msg}, {msg2}")

    @classmethod
    def gen_updated_parameters(cls, job_id, initiator_role, initiator_party_id, input_job_parameters, input_component_parameters):
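        """
        Merge the input job/component parameters into the stored job
        configuration and return the updated job parameters, the updated
        component parameters and the list of updated components.
        """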
        # TODO: check for job parameters that are not allowed to be updated
        job_configuration = job_utils.get_job_configuration(job_id=job_id,
                                                            role=initiator_role,
                                                            party_id=initiator_party_id)
        updated_job_parameters = job_configuration.runtime_conf["job_parameters"]
        updated_component_parameters = job_configuration.runtime_conf["component_parameters"]
        if input_job_parameters:
            if input_job_parameters.get("common"):
                common_job_parameters = RunParameters(**input_job_parameters["common"])
                cls.create_common_job_parameters(job_id=job_id, initiator_role=initiator_role, common_job_parameters=common_job_parameters)
                for attr in {"model_id", "model_version"}:
                    setattr(common_job_parameters, attr, updated_job_parameters["common"].get(attr))
                updated_job_parameters["common"] = common_job_parameters.to_dict()
            # updating role-specific parameters is not supported
        updated_components = set()
        if input_component_parameters:
            cls.merge_update(input_component_parameters, updated_component_parameters)
        return updated_job_parameters, updated_component_parameters, list(updated_components)

    @classmethod
    def merge_update(cls, inputs: dict, results: dict):
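        """
        Recursively merge the ``inputs`` dict into the ``results`` dict
        in place: nested dicts are merged key by key, any other value in
        ``inputs`` overwrites the one in ``results``.
        """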
        if not isinstance(inputs, dict) or not isinstance(results, dict):
            raise ValueError(f"inputs and results must both be dict, got {type(inputs)} and {type(results)}")
        for k, v in inputs.items():
            if k not in results:
                results[k] = v
            elif isinstance(v, dict):
                cls.merge_update(v, results[k])
            else:
                results[k] = v

    @classmethod
    def update_parameter(cls, job_id, role, party_id, updated_parameters: dict):
        job_configuration = job_utils.get_job_configuration(job_id=job_id,
                                                            role=role,
                                                            party_id=party_id)
        job_parameters = updated_parameters.get("job_parameters")
        component_parameters = updated_parameters.get("component_parameters")
        if job_parameters:
            job_configuration.runtime_conf["job_parameters"] = job_parameters
            job_parameters = RunParameters(**job_parameters["common"])
            cls.create_job_parameters_on_party(role=role,
                                               party_id=party_id,
                                               job_parameters=job_parameters)
            job_configuration.runtime_conf_on_party["job_parameters"] = job_parameters.to_dict()
        if component_parameters:
            job_configuration.runtime_conf["component_parameters"] = component_parameters
            job_configuration.runtime_conf_on_party["component_parameters"] = component_parameters
        job_info = {}
        job_info["job_id"] = job_id
        job_info["role"] = role
        job_info["party_id"] = party_id
        job_info["runtime_conf"] = job_configuration.runtime_conf
        job_info["runtime_conf_on_party"] = job_configuration.runtime_conf_on_party
        JobSaver.update_job(job_info)

    @classmethod
    def initialize_task(cls, role, party_id, task_info: dict):
        task_info["role"] = role
        task_info["party_id"] = party_id
        initialized_result, provider_group = cls.initialize_tasks(components=[task_info["component_name"]], **task_info)
        return initialized_result

    @classmethod
    def initialize_tasks(cls, job_id, role, party_id, run_on_this_party, initiator_role, initiator_party_id,
                         job_parameters: RunParameters = None, dsl_parser=None, components: list = None,
                         runtime_conf=None, check_version=False, is_scheduler=False, **kwargs):
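        """
        Build the common task info, resolve the component provider groups,
        and either start a TASK_INITIALIZER worker on this party or create
        task placeholders for scheduling. Returns the initialization
        result and the provider groups.
        """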
        common_task_info = {}
        common_task_info["job_id"] = job_id
        common_task_info["initiator_role"] = initiator_role
        common_task_info["initiator_party_id"] = initiator_party_id
        common_task_info["role"] = role
        common_task_info["party_id"] = party_id
        common_task_info["run_on_this_party"] = run_on_this_party
        common_task_info["federated_mode"] = kwargs.get("federated_mode", job_parameters.federated_mode if job_parameters else None)
        common_task_info["federated_status_collect_type"] = kwargs.get("federated_status_collect_type", job_parameters.federated_status_collect_type if job_parameters else None)
        common_task_info["auto_retries"] = kwargs.get("auto_retries", job_parameters.auto_retries if job_parameters else None)
        common_task_info["auto_retry_delay"] = kwargs.get("auto_retry_delay", job_parameters.auto_retry_delay if job_parameters else None)
        common_task_info["task_version"] = kwargs.get("task_version")
        if role == "local":
            common_task_info["run_ip"] = RuntimeConfig.JOB_SERVER_HOST
            common_task_info["run_port"] = RuntimeConfig.HTTP_PORT
        if dsl_parser is None:
            dsl_parser, runtime_conf, dsl = schedule_utils.get_job_dsl_parser_by_job_id(job_id)
        provider_group = ProviderManager.get_job_provider_group(dsl_parser=dsl_parser,
                                                                runtime_conf=runtime_conf,
                                                                components=components,
                                                                role=role,
                                                                party_id=party_id,
                                                                check_version=check_version,
                                                                is_scheduler=is_scheduler)
        initialized_result = {}
        for group_key, group_info in provider_group.items():
            initialized_config = {}
            initialized_config.update(group_info)
            initialized_config["common_task_info"] = common_task_info
            if run_on_this_party:
                code, _result = WorkerManager.start_general_worker(worker_name=WorkerName.TASK_INITIALIZER,
                                                                   job_id=job_id,
                                                                   role=role,
                                                                   party_id=party_id,
                                                                   initialized_config=initialized_config,
                                                                   run_in_subprocess=False if initialized_config["if_default_provider"] else True)
                initialized_result.update(_result)
            else:
                cls.initialize_task_holder_for_scheduling(role=role,
                                                          party_id=party_id,
                                                          components=initialized_config["components"],
                                                          common_task_info=common_task_info,
                                                          provider_info=initialized_config["provider"])
        return initialized_result, provider_group

    @classmethod
    def initialize_task_holder_for_scheduling(cls, role, party_id, components, common_task_info, provider_info):
        for component_name in components:
            task_info = {}
            task_info.update(common_task_info)
            task_info["component_name"] = component_name
            task_info["component_module"] = ""
            task_info["provider_info"] = provider_info
            task_info["component_parameters"] = {}
            TaskController.create_task(role=role, party_id=party_id,
                                       run_on_this_party=common_task_info["run_on_this_party"],
                                       task_info=task_info)

    @classmethod
    def initialize_job_tracker(cls, job_id, role, party_id, job_parameters: RunParameters, roles, is_initiator, dsl_parser):
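        """
        Create the job tracker and log the job view: which roles/parties
        are shown, who the partners are, and which datasets are used.
        """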
        tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                          model_id=job_parameters.model_id,
                          model_version=job_parameters.model_version,
                          job_parameters=job_parameters)
        partner = {}
        show_role = {}
        for _role, _role_party in roles.items():
            if is_initiator or _role == role:
                show_role[_role] = show_role.get(_role, [])
                for _party_id in _role_party:
                    if is_initiator or _party_id == party_id:
                        show_role[_role].append(_party_id)
            if _role != role:
                partner[_role] = partner.get(_role, [])
                partner[_role].extend(_role_party)
            else:
                for _party_id in _role_party:
                    if _party_id != party_id:
                        partner[_role] = partner.get(_role, [])
                        partner[_role].append(_party_id)
        job_args = dsl_parser.get_args_input()
        dataset = get_job_dataset(is_initiator, role, party_id, roles, job_args)
        tracker.log_job_view({'partner': partner, 'dataset': dataset, 'roles': show_role})

    @classmethod
    def query_job_input_args(cls, input_data, role, party_id):
        min_partition = data_utils.get_input_data_min_partitions(
            input_data, role, party_id)
        return {'min_input_data_partition': min_partition}

    @classmethod
    def align_job_args(cls, job_id, role, party_id, job_info):
        job_info["job_id"] = job_id
        job_info["role"] = role
        job_info["party_id"] = party_id
        JobSaver.update_job(job_info)

    @classmethod
    def start_job(cls, job_id, role, party_id, extra_info=None):
        schedule_logger(job_id).info(
            f"try to start job on {role} {party_id}")
        job_info = {
            "job_id": job_id,
            "role": role,
            "party_id": party_id,
            "status": JobStatus.RUNNING,
            "start_time": current_timestamp()
        }
        if extra_info:
            schedule_logger(job_id).info(f"extra info: {extra_info}")
            job_info.update(extra_info)
        cls.update_job_status(job_info=job_info)
        cls.update_job(job_info=job_info)
        schedule_logger(job_id).info(
            f"start job on {role} {party_id} successfully")

    @classmethod
    def update_job(cls, job_info):
        """
        Save to local database
        :param job_info:
        :return:
        """
        return JobSaver.update_job(job_info=job_info)

    @classmethod
    def update_job_status(cls, job_info):
        update_status = JobSaver.update_job_status(job_info=job_info)
        if update_status and EndStatus.contains(job_info.get("status")):
            ResourceManager.return_job_resource(
                job_id=job_info["job_id"], role=job_info["role"], party_id=job_info["party_id"])
        return update_status

    @classmethod
    def stop_jobs(cls, job_id, stop_status, role=None, party_id=None):
        if role and party_id:
            jobs = JobSaver.query_job(
                job_id=job_id, role=role, party_id=party_id)
        else:
            jobs = JobSaver.query_job(job_id=job_id)
        kill_status = True
        kill_details = {}
        for job in jobs:
            kill_job_status, kill_job_details = cls.stop_job(
                job=job, stop_status=stop_status)
            kill_status = kill_status & kill_job_status
            kill_details[job_id] = kill_job_details
        return kill_status, kill_details

    @classmethod
    def stop_job(cls, job, stop_status):
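        """
        Stop all unfinished tasks of the job through the federated
        scheduler; if every task is killed successfully, update the job
        status to ``stop_status``. Returns the overall kill status and
        the per-task details.
        """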
        tasks = JobSaver.query_task(
            job_id=job.f_job_id, role=job.f_role, party_id=job.f_party_id, only_latest=True, reverse=True)
        kill_status = True
        kill_details = {}
        for task in tasks:
            if task.f_status in [TaskStatus.SUCCESS, TaskStatus.WAITING, TaskStatus.PASS]:
                continue
            kill_task_status = False
            status, response = FederatedScheduler.stop_task(job=job, task=task, stop_status=stop_status)
            if status == RetCode.SUCCESS:
                kill_task_status = True
            kill_status = kill_status & kill_task_status
            kill_details[task.f_task_id] = 'success' if kill_task_status else 'failed'
        if kill_status:
            job_info = job.to_human_model_dict(only_primary_with=["status"])
            job_info["status"] = stop_status
            JobController.update_job_status(job_info)
        return kill_status, kill_details

    # Job status depends on the final operation result and is calculated by the initiator
    @classmethod
    def save_pipelined_model(cls, job_id, role, party_id):
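        """
        Build and save the pipeline model (train/inference DSL, runtime
        conf and metadata) for a finished training job on this party.
        Local jobs, assistant roles and predict jobs are skipped.
        """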
        if role == 'local':
            schedule_logger(job_id).info('A job of local role does not need to save pipeline model')
            return
        schedule_logger(job_id).info(f'start to save pipeline model on {role} {party_id}')

        job_configuration = job_utils.get_job_configuration(job_id, role, party_id)
        runtime_conf_on_party = job_configuration.runtime_conf_on_party
        job_parameters = runtime_conf_on_party['job_parameters']

        model_id = job_parameters['model_id']
        model_version = job_parameters['model_version']
        job_type = job_parameters.get('job_type', '')
        roles = runtime_conf_on_party['role']
        initiator_role = runtime_conf_on_party['initiator']['role']
        initiator_party_id = runtime_conf_on_party['initiator']['party_id']
        assistant_role = job_parameters.get('assistant_role', [])
        if role in set(assistant_role) or job_type == 'predict':
            return

        dsl_parser = schedule_utils.get_job_dsl_parser(
            dsl=job_configuration.dsl,
            runtime_conf=job_configuration.runtime_conf,
            train_runtime_conf=job_configuration.train_runtime_conf,
        )
        tasks = JobSaver.query_task(
            job_id=job_id,
            role=role,
            party_id=party_id,
            only_latest=True,
        )
        components_parameters = {
            task.f_component_name: task.f_component_parameters for task in tasks
        }
        predict_dsl = schedule_utils.fill_inference_dsl(dsl_parser, job_configuration.dsl, components_parameters)

        pipeline = pipeline_pb2.Pipeline()
        pipeline.roles = json_dumps(roles, byte=True)
        pipeline.model_id = model_id
        pipeline.model_version = model_version
        pipeline.initiator_role = initiator_role
        pipeline.initiator_party_id = initiator_party_id
        pipeline.train_dsl = json_dumps(job_configuration.dsl, byte=True)
        pipeline.train_runtime_conf = json_dumps(job_configuration.runtime_conf, byte=True)
        pipeline.runtime_conf_on_party = json_dumps(runtime_conf_on_party, byte=True)
        pipeline.inference_dsl = json_dumps(predict_dsl, byte=True)
        pipeline.fate_version = RuntimeConfig.get_env('FATE')
        pipeline.parent = True
        pipeline.parent_info = json_dumps({}, byte=True)
        pipeline.loaded_times = 0

        tracker = Tracker(
            job_id=job_id, role=role, party_id=party_id,
            model_id=model_id, model_version=model_version,
            job_parameters=RunParameters(**job_parameters),
        )

        if ENABLE_MODEL_STORE:
            query = tracker.pipelined_model.pipelined_component.get_define_meta_from_db()
            for row in query:
                sync_component = SyncComponent(
                    role=role, party_id=party_id,
                    model_id=model_id, model_version=model_version,
                    component_name=row.f_component_name,
                )
                if not sync_component.local_exists() and sync_component.remote_exists():
                    sync_component.download()

        tracker.pipelined_model.save_pipeline_model(pipeline)

        model_info = gather_model_info_data(tracker.pipelined_model)
        save_model_info(model_info)

        schedule_logger(job_id).info(f'save pipeline on {role} {party_id} successfully')

    @classmethod
    def clean_job(cls, job_id, role, party_id, roles):
        pass
        # schedule_logger(job_id).info(f"start to clean job on {role} {party_id}")
        # TODO: clean job
        # schedule_logger(job_id).info(f"job on {role} {party_id} clean done")

    @classmethod
    def job_reload(cls, job):
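        """
        Reload an inherited job: copy the logs and outputs of the
        inherited components from the source job and, on the initiator,
        refresh the task sets used for the status reload.
        """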
        schedule_logger(job.f_job_id).info("start job reload")
        cls.log_reload(job)
        source_inheritance_tasks, target_inheritance_tasks = cls.load_source_target_tasks(job)
        schedule_logger(job.f_job_id).info(f"source_inheritance_tasks: {source_inheritance_tasks}, target_inheritance_tasks: {target_inheritance_tasks}")
        cls.output_reload(job, source_inheritance_tasks, target_inheritance_tasks)
        if job.f_is_initiator:
            source_inheritance_tasks, target_inheritance_tasks = cls.load_source_target_tasks(job, update_status=True)
        cls.status_reload(job, source_inheritance_tasks, target_inheritance_tasks)

    @classmethod
    def load_source_target_tasks(cls, job, update_status=False):
        filters = {"component_list": job.f_inheritance_info.get("component_list", [])}
        if not update_status:
            filters.update({"role": job.f_role, "party_id": job.f_party_id})
        source_inheritance_tasks = cls.load_tasks(job_id=job.f_inheritance_info.get("job_id"), **filters)
        target_inheritance_tasks = cls.load_tasks(job_id=job.f_job_id, **filters)
        return source_inheritance_tasks, target_inheritance_tasks

    @classmethod
    def load_tasks(cls, component_list, job_id, **kwargs):
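        """
        Query the latest tasks of the job and index the ones belonging to
        ``component_list`` by ``{component}_{role}_{task_version}``.
        """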
        tasks = JobSaver.query_task(job_id=job_id, only_latest=True, **kwargs)
        task_dict = {}
        for cpn in component_list:
            for task in tasks:
                if cpn == task.f_component_name:
                    task_dict[f"{cpn}_{task.f_role}_{task.f_task_version}"] = task
        return task_dict

    @classmethod
    def load_task_tracker(cls, tasks: dict):
        tracker_dict = {}
        for key, task in tasks.items():
            schedule_logger(task.f_job_id).info(
                f"task: {task.f_job_id}, {task.f_role}, {task.f_party_id}, {task.f_component_name}, {task.f_task_version}")
            tracker = Tracker(job_id=task.f_job_id, role=task.f_role, party_id=task.f_party_id,
                              component_name=task.f_component_name,
                              task_id=task.f_task_id,
                              task_version=task.f_task_version)
            tracker_dict[key] = tracker
        return tracker_dict

    @classmethod
    def log_reload(cls, job):
        schedule_logger(job.f_job_id).info("start reload job log")
        if job.f_inheritance_info:
            for component_name in job.f_inheritance_info.get("component_list"):
                source_path = os.path.join(log_utils.get_logger_base_dir(), job.f_inheritance_info.get("job_id"), job.f_role, job.f_party_id, component_name)
                target_path = os.path.join(log_utils.get_logger_base_dir(), job.f_job_id, job.f_role, job.f_party_id, component_name)
                if os.path.exists(source_path):
                    if os.path.exists(target_path):
                        shutil.rmtree(target_path)
                    shutil.copytree(source_path, target_path)
        schedule_logger(job.f_job_id).info("reload job log success")

    @classmethod
    def output_reload(cls, job, source_tasks: dict, target_tasks: dict):
        # model reload
        schedule_logger(job.f_job_id).info("start reload model")
        source_jobs = JobSaver.query_job(job_id=job.f_inheritance_info["job_id"], role=job.f_role, party_id=job.f_party_id)
        if source_jobs:
            cls.output_model_reload(job, source_jobs[0])
        schedule_logger(job.f_job_id).info("start reload data")
        source_tracker_dict = cls.load_task_tracker(source_tasks)
        target_tracker_dict = cls.load_task_tracker(target_tasks)
        for key, source_tracker in source_tracker_dict.items():
            target_tracker = target_tracker_dict[key]
            table_infos = source_tracker.get_output_data_info()
            # data reload
            schedule_logger(job.f_job_id).info(f"table infos: {table_infos}")
            for table in table_infos:
                target_tracker.log_output_data_info(data_name=table.f_data_name,
                                                    table_namespace=table.f_table_namespace,
                                                    table_name=table.f_table_name)
            # cache reload
            schedule_logger(job.f_job_id).info("start reload cache")
            cache_list = source_tracker.query_output_cache_record()
            for cache in cache_list:
                schedule_logger(job.f_job_id).info(f"start reload cache name: {cache.f_cache_name}")
                target_tracker.tracking_output_cache(cache.f_cache, cache_name=cache.f_cache_name)
            # summary reload
            schedule_logger(job.f_job_id).info("start reload summary")
            target_tracker.reload_summary(source_tracker=source_tracker)
            # metric reload
            schedule_logger(job.f_job_id).info("start reload metric")
            target_tracker.reload_metric(source_tracker=source_tracker)
        schedule_logger(job.f_job_id).info("reload output success")

    @classmethod
    def status_reload(cls, job, source_tasks, target_tasks):
        schedule_logger(job.f_job_id).info("start reload status")
        # update task status
        for key, source_task in source_tasks.items():
            try:
                JobSaver.reload_task(source_task, target_tasks[key])
            except Exception as e:
                schedule_logger(job.f_job_id).warning(f"reload failed: {e}")
        # update job status
        JobSaver.update_job(job_info={
            "job_id": job.f_job_id,
            "role": job.f_role,
            "party_id": job.f_party_id,
            "inheritance_status": JobInheritanceStatus.SUCCESS,
        })
        schedule_logger(job.f_job_id).info("reload status success")

    @classmethod
    def output_model_reload(cls, job, source_job):
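        """
        Copy the pipelined component files (variables data, run parameters
        and checkpoints) of the inherited components from the source job's
        model directory and replicate their define meta records.
        """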
        source_pipelined_component = PipelinedComponent(
            role=source_job.f_role, party_id=source_job.f_party_id,
            model_id=source_job.f_runtime_conf['job_parameters']['common']['model_id'],
            model_version=source_job.f_job_id,
        )
        target_pipelined_component = PipelinedComponent(
            role=job.f_role, party_id=job.f_party_id,
            model_id=job.f_runtime_conf['job_parameters']['common']['model_id'],
            model_version=job.f_job_id,
        )

        query_args = (
            PipelineComponentMeta.f_component_name.in_(job.f_inheritance_info['component_list']),
        )
        query = source_pipelined_component.get_define_meta_from_db(*query_args)

        for row in query:
            for i in ('variables_data_path', 'run_parameters_path', 'checkpoint_path'):
                source_dir = getattr(source_pipelined_component, i) / row.f_component_name
                target_dir = getattr(target_pipelined_component, i) / row.f_component_name

                if not source_dir.is_dir():
                    continue
                if target_dir.is_dir():
                    shutil.rmtree(target_dir)
                shutil.copytree(source_dir, target_dir)

        source_pipelined_component.replicate_define_meta({
            'f_role': target_pipelined_component.role,
            'f_party_id': target_pipelined_component.party_id,
            'f_model_id': target_pipelined_component.model_id,
            'f_model_version': target_pipelined_component.model_version,
        }, query_args, True)