task_executor.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import os
import sys
import traceback

from fate_arch import session, storage
from fate_arch.common import EngineType, profile
from fate_arch.common.base_utils import current_timestamp, json_dumps
from fate_arch.computing import ComputingEngine

from fate_flow.component_env_utils import provider_utils
from fate_flow.db.component_registry import ComponentRegistry
from fate_flow.db.db_models import TrackingOutputDataInfo, fill_db_model_object
from fate_flow.db.runtime_config import RuntimeConfig
from fate_flow.entity import DataCache, RunParameters
from fate_flow.entity.run_status import TaskStatus
from fate_flow.errors import PassError
from fate_flow.hook import HookManager
from fate_flow.manager.data_manager import DataTableTracker
from fate_flow.manager.provider_manager import ProviderManager
from fate_flow.model.checkpoint import CheckpointManager
from fate_flow.operation.job_tracker import Tracker
from fate_flow.scheduling_apps.client import TrackerClient
from fate_flow.settings import ERROR_REPORT, ERROR_REPORT_WITH_PATH
from fate_flow.utils import job_utils, schedule_utils
from fate_flow.utils.base_utils import get_fate_flow_python_directory
from fate_flow.utils.log_utils import getLogger, replace_ip
from fate_flow.utils.model_utils import gen_party_model_id
from fate_flow.worker.task_base_worker import BaseTaskWorker, ComponentInput

LOGGER = getLogger()
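
# TaskExecutor runs a single component task for one (role, party) inside its
# own worker process: it resolves the component provider, builds the computing
# and federation sessions, feeds the DSL inputs to the component, persists the
# outputs, and reports status back to the driver.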
class TaskExecutor(BaseTaskWorker):
    def _run_(self):
        # todo: All function calls where errors should be thrown
        args = self.args
        start_time = current_timestamp()
        try:
            LOGGER.info(f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
            HookManager.init()
            self.report_info.update({
                "job_id": args.job_id,
                "component_name": args.component_name,
                "task_id": args.task_id,
                "task_version": args.task_version,
                "role": args.role,
                "party_id": args.party_id,
                "run_ip": args.run_ip,
                "run_port": args.run_port,
                "run_pid": self.run_pid
            })
            job_configuration = job_utils.get_job_configuration(
                job_id=self.args.job_id,
                role=self.args.role,
                party_id=self.args.party_id
            )
            task_parameters_conf = args.config
            dsl_parser = schedule_utils.get_job_dsl_parser(dsl=job_configuration.dsl,
                                                           runtime_conf=job_configuration.runtime_conf,
                                                           train_runtime_conf=job_configuration.train_runtime_conf,
                                                           pipeline_dsl=None)
            job_parameters = dsl_parser.get_job_parameters(job_configuration.runtime_conf)
            user_name = job_parameters.get(args.role, {}).get(args.party_id, {}).get("user", '')
            LOGGER.info(f"user name: {user_name}")
            task_parameters = RunParameters(**task_parameters_conf)
            job_parameters = task_parameters
            if job_parameters.assistant_role:
                TaskExecutor.monkey_patch()
            job_args_on_party = TaskExecutor.get_job_args_on_party(dsl_parser, job_configuration.runtime_conf_on_party, args.role, args.party_id)
            component = dsl_parser.get_component_info(component_name=args.component_name)
            module_name = component.get_module()
            task_input_dsl = component.get_input()
            task_output_dsl = component.get_output()
            party_model_id = gen_party_model_id(job_parameters.model_id, args.role, args.party_id)
            model_version = job_parameters.model_version if job_parameters.job_type != 'predict' else args.job_id
            kwargs = {
                'job_id': args.job_id,
                'role': args.role,
                'party_id': args.party_id,
                'component_name': args.component_name,
                'task_id': args.task_id,
                'task_version': args.task_version,
                'model_id': job_parameters.model_id,
                # in the prediction job, job_parameters.model_version comes from the training job
                # TODO: prediction job should not affect training job
                'model_version': job_parameters.model_version,
                'component_module_name': module_name,
                'job_parameters': job_parameters,
            }
            tracker = Tracker(**kwargs)
            tracker_client = TrackerClient(**kwargs)
            checkpoint_manager = CheckpointManager(**kwargs)
            predict_tracker_client = None
            if job_parameters.job_type == 'predict':
                kwargs['model_version'] = model_version
                predict_tracker_client = TrackerClient(**kwargs)
            self.report_info["party_status"] = TaskStatus.RUNNING
            self.report_task_info_to_driver()
            previous_components_parameters = tracker_client.get_model_run_parameters()
            LOGGER.info(f"previous_components_parameters:\n{json_dumps(previous_components_parameters, indent=4)}")
            component_provider, component_parameters_on_party, user_specified_parameters = \
                ProviderManager.get_component_run_info(dsl_parser=dsl_parser, component_name=args.component_name,
                                                       role=args.role, party_id=args.party_id,
                                                       previous_components_parameters=previous_components_parameters)
            RuntimeConfig.set_component_provider(component_provider)
            LOGGER.info(f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")
            flow_feeded_parameters = {"output_data_name": task_output_dsl.get("data")}
            # init environment, process is shared globally
            RuntimeConfig.init_config(COMPUTING_ENGINE=job_parameters.computing_engine,
                                      FEDERATION_ENGINE=job_parameters.federation_engine,
                                      FEDERATED_MODE=job_parameters.federated_mode)
            if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
                session_options = task_parameters.eggroll_run.copy()
                session_options["python.path"] = os.getenv("PYTHONPATH")
                session_options["python.venv"] = os.getenv("VIRTUAL_ENV")
            else:
                session_options = {}
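            # The computing session is created first and installed as the
            # process-global session; federation is only initialized when the
            # job involves roles other than "local".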
            sess = session.Session(session_id=args.session_id)
            sess.as_global()
            sess.init_computing(computing_session_id=args.session_id, options=session_options)
            component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
            roles = job_configuration.runtime_conf["role"]
            if set(roles) == {"local"}:
                LOGGER.info("only local roles, skip federation init")
            else:
                sess.init_federation(federation_session_id=args.federation_session_id,
                                     runtime_conf=component_parameters_on_party,
                                     service_conf=job_parameters.engines_address.get(EngineType.FEDERATION, {}))
            LOGGER.info(f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
            LOGGER.info(f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")
            LOGGER.info(f"task input dsl {task_input_dsl}")
            task_run_args, input_table_list = self.get_task_run_args(job_id=args.job_id, role=args.role, party_id=args.party_id,
                                                                     task_id=args.task_id,
                                                                     task_version=args.task_version,
                                                                     job_args=job_args_on_party,
                                                                     job_parameters=job_parameters,
                                                                     task_parameters=task_parameters,
                                                                     input_dsl=task_input_dsl,
                                                                     )
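            # Built-in IO components operate on storage directly, so they
            # receive the raw job parameters in addition to their DSL inputs.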
            if module_name in {"Upload", "Download", "Reader", "Writer", "Checkpoint"}:
                task_run_args["job_parameters"] = job_parameters
            # LOGGER.info(f"task input args {task_run_args}")
            need_run = component_parameters_on_party.get("ComponentParam", {}).get("need_run", True)
            provider_interface = provider_utils.get_provider_interface(provider=component_provider)
            run_object = provider_interface.get(module_name, ComponentRegistry.get_provider_components(provider_name=component_provider.name, provider_version=component_provider.version)).get_run_obj(self.args.role)
            flow_feeded_parameters.update({"table_info": input_table_list})
            cpn_input = ComponentInput(
                tracker=tracker_client,
                checkpoint_manager=checkpoint_manager,
                task_version_id=job_utils.generate_task_version_id(args.task_id, args.task_version),
                parameters=component_parameters_on_party["ComponentParam"],
                datasets=task_run_args.get("data", None),
                caches=task_run_args.get("cache", None),
                models=dict(
                    model=task_run_args.get("model"),
                    isometric_model=task_run_args.get("isometric_model"),
                ),
                job_parameters=job_parameters,
                roles=dict(
                    role=component_parameters_on_party["role"],
                    local=component_parameters_on_party["local"],
                ),
                flow_feeded_parameters=flow_feeded_parameters,
            )
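            # Profiling is opt-in via the FATE_PROFILE_LOG_ENABLED environment
            # variable; when enabled, the component run is wrapped in
            # profile_start()/profile_ends() so engine calls are timed.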
            profile_log_enabled = False
            try:
                if int(os.getenv("FATE_PROFILE_LOG_ENABLED", "0")) > 0:
                    profile_log_enabled = True
            except Exception as e:
                LOGGER.warning(e)
            if profile_log_enabled:
                # add profile logs
                LOGGER.info("profile logging is enabled")
                profile.profile_start()
                cpn_output = run_object.run(cpn_input)
                sess.wait_remote_all_done()
                profile.profile_ends()
            else:
                LOGGER.info("profile logging is disabled")
                cpn_output = run_object.run(cpn_input)
                sess.wait_remote_all_done()
            LOGGER.info(f"task output dsl {task_output_dsl}")
            LOGGER.info(f"task output data {cpn_output.data}")
            output_table_list = []
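            # Persist every output computing table to the configured storage
            # engine and record its (namespace, name) through the tracker so
            # downstream components can look it up.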
            for index, data in enumerate(cpn_output.data):
                data_name = task_output_dsl.get('data')[index] if task_output_dsl.get('data') else '{}'.format(index)
                # todo: the token depends on the engine type, maybe in job parameters
                persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                    computing_table=data,
                    output_storage_engine=job_parameters.storage_engine,
                    token={"username": user_name})
                if persistent_table_namespace and persistent_table_name:
                    tracker.log_output_data_info(data_name=data_name,
                                                 table_namespace=persistent_table_namespace,
                                                 table_name=persistent_table_name)
                    output_table_list.append({"namespace": persistent_table_namespace, "name": persistent_table_name})
            self.log_output_data_table_tracker(args.job_id, input_table_list, output_table_list)
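            # Models are saved through the prediction job's tracker when one
            # exists, so predict runs do not write into the training job's
            # model version.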
            if cpn_output.model:
                (predict_tracker_client or tracker_client).save_component_output_model(
                    model_buffers=cpn_output.model,
                    # There is only one model output at the current dsl version
                    model_alias=task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default',
                    user_specified_run_parameters=user_specified_parameters,
                )
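            # Output caches are either tracked by reference (DataCache) or
            # saved as (data, meta) tuples to the storage engine.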
            if cpn_output.cache:
                for i, cache in enumerate(cpn_output.cache):
                    if cache is None:
                        continue
                    name = task_output_dsl.get("cache")[i] if "cache" in task_output_dsl else str(i)
                    if isinstance(cache, DataCache):
                        tracker.tracking_output_cache(cache, cache_name=name)
                    elif isinstance(cache, tuple):
                        tracker.save_output_cache(cache_data=cache[0],
                                                  cache_meta=cache[1],
                                                  cache_name=name,
                                                  output_storage_engine=job_parameters.storage_engine,
                                                  output_storage_address=job_parameters.engines_address.get(EngineType.STORAGE, {}),
                                                  token={"username": user_name})
                    else:
                        raise RuntimeError(f"unsupported output cache type {type(cache)} from module run object")
            self.report_info["party_status"] = TaskStatus.SUCCESS if need_run else TaskStatus.PASS
        except PassError:
            self.report_info["party_status"] = TaskStatus.PASS
        except Exception as e:
            traceback.print_exc()
            self.report_info["party_status"] = TaskStatus.FAILED
            self.generate_error_report()
            LOGGER.exception(e)
            try:
                LOGGER.info("start destroy sessions")
                sess.destroy_all_sessions()
                LOGGER.info("destroy all sessions success")
            except Exception as e:
                LOGGER.exception(e)
        finally:
            try:
                self.report_info["end_time"] = current_timestamp()
                self.report_info["elapsed"] = self.report_info["end_time"] - start_time
                self.report_task_info_to_driver()
            except Exception as e:
                self.report_info["party_status"] = TaskStatus.FAILED
                traceback.print_exc()
                LOGGER.exception(e)
        msg = f"finish {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} with {self.report_info['party_status']}"
        LOGGER.info(msg)
        print(msg)
        return self.report_info
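
    # Record table lineage: each persisted output table is linked to the input
    # tables it was derived from, so data provenance can be traced per job.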
    @classmethod
    def log_output_data_table_tracker(cls, job_id, input_table_list, output_table_list):
        try:
            parent_number = 0
            if len(input_table_list) > 1 and len(output_table_list) > 1:
                # TODO
                return
            for input_table in input_table_list:
                for output_table in output_table_list:
                    DataTableTracker.create_table_tracker(output_table.get("name"), output_table.get("namespace"),
                                                          entity_info={
                                                              "have_parent": True,
                                                              "parent_table_namespace": input_table.get("namespace"),
                                                              "parent_table_name": input_table.get("name"),
                                                              "parent_number": parent_number,
                                                              "job_id": job_id
                                                          })
                parent_number += 1
        except Exception as e:
            LOGGER.exception(e)

    @classmethod
    def get_job_args_on_party(cls, dsl_parser, job_runtime_conf, role, party_id):
        party_index = job_runtime_conf["role"][role].index(int(party_id))
        job_args = dsl_parser.get_args_input()
        job_args_on_party = job_args[role][party_index].get('args') if role in job_args else {}
        return job_args_on_party
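
    # Resolve the component's DSL inputs into runnable arguments. As an
    # illustration (shape only, the names are hypothetical), an input_dsl such as
    #     {"data": {"train_data": ["reader_0.data"]},
    #      "model": ["hetero_lr_0.model"]}
    # is resolved by looking up each "component.output_name" reference through
    # the tracker and loading the backing storage table into a computing table.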
    @classmethod
    def get_task_run_args(cls, job_id, role, party_id, task_id, task_version,
                          job_args, job_parameters: RunParameters, task_parameters: RunParameters,
                          input_dsl, filter_type=None, filter_attr=None, get_input_table=False):
        task_run_args = {}
        input_table = {}
        input_table_info_list = []
        if 'idmapping' in role:
            return {}
        for input_type, input_detail in input_dsl.items():
            if filter_type and input_type not in filter_type:
                continue
            if input_type == 'data':
                this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
                for data_type, data_list in input_detail.items():
                    data_dict = {}
                    for data_key in data_list:
                        data_key_item = data_key.split('.')
                        data_dict[data_key_item[0]] = {data_type: []}
                    for data_key in data_list:
                        data_key_item = data_key.split('.')
                        search_component_name, search_data_name = data_key_item[0], data_key_item[1]
                        storage_table_meta = None
                        tracker_client = TrackerClient(job_id=job_id, role=role, party_id=party_id,
                                                       component_name=search_component_name,
                                                       task_id=task_id, task_version=task_version)
                        if search_component_name == 'args':
                            if job_args.get('data', {}).get(search_data_name, {}).get('namespace', '') and job_args.get(
                                    'data', {}).get(search_data_name, {}).get('name', ''):
                                storage_table_meta = storage.StorageTableMeta(
                                    name=job_args['data'][search_data_name]['name'],
                                    namespace=job_args['data'][search_data_name]['namespace'])
                        else:
                            upstream_output_table_infos_json = tracker_client.get_output_data_info(
                                data_name=search_data_name)
                            if upstream_output_table_infos_json:
                                tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                                                  component_name=search_component_name,
                                                  task_id=task_id, task_version=task_version)
                                upstream_output_table_infos = []
                                for _ in upstream_output_table_infos_json:
                                    upstream_output_table_infos.append(fill_db_model_object(
                                        Tracker.get_dynamic_db_model(TrackingOutputDataInfo, job_id)(), _))
                                output_tables_meta = tracker.get_output_data_table(upstream_output_table_infos)
                                if output_tables_meta:
                                    storage_table_meta = output_tables_meta.get(search_data_name, None)
                        args_from_component = this_type_args[search_component_name] = this_type_args.get(
                            search_component_name, {})
                        if get_input_table and storage_table_meta:
                            input_table[data_key] = {'namespace': storage_table_meta.get_namespace(),
                                                     'name': storage_table_meta.get_name()}
                            computing_table = None
                        elif storage_table_meta:
                            LOGGER.info(f"load computing table with {task_parameters.computing_partitions} partitions")
                            computing_table = session.get_computing_session().load(
                                storage_table_meta.get_address(),
                                schema=storage_table_meta.get_schema(),
                                partitions=task_parameters.computing_partitions)
                            input_table_info_list.append({'namespace': storage_table_meta.get_namespace(),
                                                          'name': storage_table_meta.get_name()})
                        else:
                            computing_table = None
                        if not computing_table or not filter_attr or not filter_attr.get("data", None):
                            data_dict[search_component_name][data_type].append(computing_table)
                            args_from_component[data_type] = data_dict[search_component_name][data_type]
                        else:
                            args_from_component[data_type] = dict(
                                [(a, getattr(computing_table, "get_{}".format(a))()) for a in filter_attr["data"]])
            elif input_type == "cache":
                this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
                for search_key in input_detail:
                    search_component_name, cache_name = search_key.split(".")
                    tracker = Tracker(job_id=job_id, role=role, party_id=party_id, component_name=search_component_name)
                    this_type_args[search_component_name] = tracker.get_output_cache(cache_name=cache_name)
            elif input_type in {'model', 'isometric_model'}:
                this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
                for dsl_model_key in input_detail:
                    dsl_model_key_items = dsl_model_key.split('.')
                    if len(dsl_model_key_items) == 2:
                        search_component_name, search_model_alias = dsl_model_key_items[0], dsl_model_key_items[1]
                    elif len(dsl_model_key_items) == 3 and dsl_model_key_items[0] == 'pipeline':
                        search_component_name, search_model_alias = dsl_model_key_items[1], dsl_model_key_items[2]
                    else:
                        raise Exception('get input {} failed'.format(input_type))
                    kwargs = {
                        'job_id': job_id,
                        'role': role,
                        'party_id': party_id,
                        'component_name': search_component_name,
                        'model_id': job_parameters.model_id,
                        # in the prediction job, job_parameters.model_version comes from the training job
                        'model_version': job_parameters.model_version,
                    }
                    # get models from the training job
                    models = TrackerClient(**kwargs).read_component_output_model(search_model_alias)
                    if not models and job_parameters.job_type == 'predict':
                        kwargs['model_version'] = job_id
                        # get models from the prediction job if not found in the training job
                        models = TrackerClient(**kwargs).read_component_output_model(search_model_alias)
                    this_type_args[search_component_name] = models
            else:
                raise Exception(f"unsupported input type: {input_type}")
        if get_input_table:
            return input_table
        return task_run_args, input_table_info_list
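
    # Apply optional monkey patches: every package under fate_flow/monkey_patch/
    # is expected to provide a monkey_patch module exposing patch_all().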
    @classmethod
    def monkey_patch(cls):
        package_name = "monkey_patch"
        package_path = os.path.join(get_fate_flow_python_directory(), "fate_flow", package_name)
        if not os.path.exists(package_path):
            return
        for f in os.listdir(package_path):
            f_path = os.path.join(package_path, f)
            if not os.path.isdir(f_path) or "__pycache__" in f_path:
                continue
            patch_module = importlib.import_module("fate_flow." + package_name + '.' + f + '.monkey_patch')
            patch_module.patch_all()
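
    # Build a sanitized error report from the current exception: file system
    # paths on PYTHONPATH and IP addresses are scrubbed from the traceback
    # unless ERROR_REPORT_WITH_PATH is enabled.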
    def generate_error_report(self):
        if ERROR_REPORT:
            _error = ""
            etype, value, tb = sys.exc_info()
            # skip empty entries: str.replace("", "xxx") would corrupt every line
            path_list = [path for path in os.getenv("PYTHONPATH", "").split(":") if path]
            for line in traceback.TracebackException(type(value), value, tb).format(chain=True):
                if not ERROR_REPORT_WITH_PATH:
                    for path in path_list:
                        line = line.replace(path, "xxx")
                    line = replace_ip(line)
                _error += line
            self.report_info["error_report"] = _error.rstrip("\n")
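
# Entry point: the executor is typically launched by fate_flow as a separate
# worker process; BaseTaskWorker parses the command-line arguments (job_id,
# component_name, task_id, ...) into self.args before _run_ is invoked.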
# this file may not be running on the same machine as fate_flow,
# so we need to use the tracker to get the input and save the output
if __name__ == '__main__':
    worker = TaskExecutor()
    worker.run()
    worker.report_task_info_to_driver()