#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import os
import sys
import traceback

from fate_arch import session, storage
from fate_arch.common import EngineType, profile
from fate_arch.common.base_utils import current_timestamp, json_dumps
from fate_arch.computing import ComputingEngine

from fate_flow.component_env_utils import provider_utils
from fate_flow.db.component_registry import ComponentRegistry
from fate_flow.db.db_models import TrackingOutputDataInfo, fill_db_model_object
from fate_flow.db.runtime_config import RuntimeConfig
from fate_flow.entity import DataCache, RunParameters
from fate_flow.entity.run_status import TaskStatus
from fate_flow.errors import PassError
from fate_flow.hook import HookManager
from fate_flow.manager.data_manager import DataTableTracker
from fate_flow.manager.provider_manager import ProviderManager
from fate_flow.model.checkpoint import CheckpointManager
from fate_flow.operation.job_tracker import Tracker
from fate_flow.scheduling_apps.client import TrackerClient
from fate_flow.settings import ERROR_REPORT, ERROR_REPORT_WITH_PATH
from fate_flow.utils import job_utils, schedule_utils
from fate_flow.utils.base_utils import get_fate_flow_python_directory
from fate_flow.utils.log_utils import getLogger, replace_ip
from fate_flow.utils.model_utils import gen_party_model_id
from fate_flow.worker.task_base_worker import BaseTaskWorker, ComponentInput

LOGGER = getLogger()


class TaskExecutor(BaseTaskWorker):
    def _run_(self):
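        """Run a single component task end to end: load the job/task configuration,
        initialize computing and federation sessions, execute the component, then
        persist its output data, models and caches and report the final status to
        the driver.
        """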
        # TODO: all function calls where errors should be raised
        args = self.args
        start_time = current_timestamp()
        try:
            LOGGER.info(f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
            HookManager.init()
            self.report_info.update({
                "job_id": args.job_id,
                "component_name": args.component_name,
                "task_id": args.task_id,
                "task_version": args.task_version,
                "role": args.role,
                "party_id": args.party_id,
                "run_ip": args.run_ip,
                "run_port": args.run_port,
                "run_pid": self.run_pid
            })
            job_configuration = job_utils.get_job_configuration(
                job_id=self.args.job_id,
                role=self.args.role,
                party_id=self.args.party_id
            )
            task_parameters_conf = args.config
            dsl_parser = schedule_utils.get_job_dsl_parser(dsl=job_configuration.dsl,
                                                           runtime_conf=job_configuration.runtime_conf,
                                                           train_runtime_conf=job_configuration.train_runtime_conf,
                                                           pipeline_dsl=None)
            job_parameters = dsl_parser.get_job_parameters(job_configuration.runtime_conf)
            user_name = job_parameters.get(args.role, {}).get(args.party_id, {}).get("user", '')
            LOGGER.info(f"user name: {user_name}")
            task_parameters = RunParameters(**task_parameters_conf)
            job_parameters = task_parameters
            if job_parameters.assistant_role:
                TaskExecutor.monkey_patch()
            job_args_on_party = TaskExecutor.get_job_args_on_party(dsl_parser, job_configuration.runtime_conf_on_party, args.role, args.party_id)
            component = dsl_parser.get_component_info(component_name=args.component_name)
            module_name = component.get_module()
            task_input_dsl = component.get_input()
            task_output_dsl = component.get_output()
            party_model_id = gen_party_model_id(job_parameters.model_id, args.role, args.party_id)
            model_version = job_parameters.model_version if job_parameters.job_type != 'predict' else args.job_id
            kwargs = {
                'job_id': args.job_id,
                'role': args.role,
                'party_id': args.party_id,
                'component_name': args.component_name,
                'task_id': args.task_id,
                'task_version': args.task_version,
                'model_id': job_parameters.model_id,
                # in the prediction job, job_parameters.model_version comes from the training job
                # TODO: prediction job should not affect training job
                'model_version': job_parameters.model_version,
                'component_module_name': module_name,
                'job_parameters': job_parameters,
            }
            tracker = Tracker(**kwargs)
            tracker_client = TrackerClient(**kwargs)
            checkpoint_manager = CheckpointManager(**kwargs)
            predict_tracker_client = None
            if job_parameters.job_type == 'predict':
                kwargs['model_version'] = model_version
                predict_tracker_client = TrackerClient(**kwargs)
            self.report_info["party_status"] = TaskStatus.RUNNING
            self.report_task_info_to_driver()
            previous_components_parameters = tracker_client.get_model_run_parameters()
            LOGGER.info(f"previous_components_parameters:\n{json_dumps(previous_components_parameters, indent=4)}")
            component_provider, component_parameters_on_party, user_specified_parameters = \
                ProviderManager.get_component_run_info(dsl_parser=dsl_parser, component_name=args.component_name,
                                                       role=args.role, party_id=args.party_id,
                                                       previous_components_parameters=previous_components_parameters)
            RuntimeConfig.set_component_provider(component_provider)
            LOGGER.info(f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")
            flow_feeded_parameters = {"output_data_name": task_output_dsl.get("data")}
            # init environment, process is shared globally
            RuntimeConfig.init_config(COMPUTING_ENGINE=job_parameters.computing_engine,
                                      FEDERATION_ENGINE=job_parameters.federation_engine,
                                      FEDERATED_MODE=job_parameters.federated_mode)
            if RuntimeConfig.COMPUTING_ENGINE == ComputingEngine.EGGROLL:
                session_options = task_parameters.eggroll_run.copy()
                session_options["python.path"] = os.getenv("PYTHONPATH")
                session_options["python.venv"] = os.getenv("VIRTUAL_ENV")
            else:
                session_options = {}
            sess = session.Session(session_id=args.session_id)
            sess.as_global()
            sess.init_computing(computing_session_id=args.session_id, options=session_options)
            component_parameters_on_party["job_parameters"] = job_parameters.to_dict()
            roles = job_configuration.runtime_conf["role"]
            if set(roles) == {"local"}:
                LOGGER.info("only local roles, skip federation init")
            else:
                sess.init_federation(federation_session_id=args.federation_session_id,
                                     runtime_conf=component_parameters_on_party,
                                     service_conf=job_parameters.engines_address.get(EngineType.FEDERATION, {}))
            LOGGER.info(f'run {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} task')
            LOGGER.info(f"component parameters on party:\n{json_dumps(component_parameters_on_party, indent=4)}")
            LOGGER.info(f"task input dsl {task_input_dsl}")
            task_run_args, input_table_list = self.get_task_run_args(job_id=args.job_id, role=args.role, party_id=args.party_id,
                                                                     task_id=args.task_id,
                                                                     task_version=args.task_version,
                                                                     job_args=job_args_on_party,
                                                                     job_parameters=job_parameters,
                                                                     task_parameters=task_parameters,
                                                                     input_dsl=task_input_dsl,
                                                                     )
            if module_name in {"Upload", "Download", "Reader", "Writer", "Checkpoint"}:
                task_run_args["job_parameters"] = job_parameters
            # LOGGER.info(f"task input args {task_run_args}")
            need_run = component_parameters_on_party.get("ComponentParam", {}).get("need_run", True)
            provider_interface = provider_utils.get_provider_interface(provider=component_provider)
            run_object = provider_interface.get(
                module_name,
                ComponentRegistry.get_provider_components(provider_name=component_provider.name,
                                                          provider_version=component_provider.version)
            ).get_run_obj(self.args.role)
            flow_feeded_parameters.update({"table_info": input_table_list})
            cpn_input = ComponentInput(
                tracker=tracker_client,
                checkpoint_manager=checkpoint_manager,
                task_version_id=job_utils.generate_task_version_id(args.task_id, args.task_version),
                parameters=component_parameters_on_party["ComponentParam"],
                datasets=task_run_args.get("data", None),
                caches=task_run_args.get("cache", None),
                models=dict(
                    model=task_run_args.get("model"),
                    isometric_model=task_run_args.get("isometric_model"),
                ),
                job_parameters=job_parameters,
                roles=dict(
                    role=component_parameters_on_party["role"],
                    local=component_parameters_on_party["local"],
                ),
                flow_feeded_parameters=flow_feeded_parameters,
            )
            profile_log_enabled = False
            try:
                if int(os.getenv("FATE_PROFILE_LOG_ENABLED", "0")) > 0:
                    profile_log_enabled = True
            except Exception as e:
                LOGGER.warning(e)
            if profile_log_enabled:
                # add profile logs
                LOGGER.info("profile logging is enabled")
                profile.profile_start()
                cpn_output = run_object.run(cpn_input)
                sess.wait_remote_all_done()
                profile.profile_ends()
            else:
                LOGGER.info("profile logging is disabled")
                cpn_output = run_object.run(cpn_input)
                sess.wait_remote_all_done()
            LOGGER.info(f"task output dsl {task_output_dsl}")
            LOGGER.info(f"task output data {cpn_output.data}")
            output_table_list = []
            for index, data in enumerate(cpn_output.data):
                data_name = task_output_dsl.get('data')[index] if task_output_dsl.get('data') else '{}'.format(index)
                # TODO: the token depends on the engine type, maybe move it into job parameters
                persistent_table_namespace, persistent_table_name = tracker.save_output_data(
                    computing_table=data,
                    output_storage_engine=job_parameters.storage_engine,
                    token={"username": user_name})
                if persistent_table_namespace and persistent_table_name:
                    tracker.log_output_data_info(data_name=data_name,
                                                 table_namespace=persistent_table_namespace,
                                                 table_name=persistent_table_name)
                    output_table_list.append({"namespace": persistent_table_namespace, "name": persistent_table_name})
            self.log_output_data_table_tracker(args.job_id, input_table_list, output_table_list)
            if cpn_output.model:
                model_tracker = tracker_client if predict_tracker_client is None else predict_tracker_client
                model_tracker.save_component_output_model(
                    model_buffers=cpn_output.model,
                    # there is only one model output at the current dsl version
                    model_alias=task_output_dsl['model'][0] if task_output_dsl.get('model') else 'default',
                    user_specified_run_parameters=user_specified_parameters,
                )
            if cpn_output.cache:
                for i, cache in enumerate(cpn_output.cache):
                    if cache is None:
                        continue
                    name = task_output_dsl.get("cache")[i] if "cache" in task_output_dsl else str(i)
                    if isinstance(cache, DataCache):
                        tracker.tracking_output_cache(cache, cache_name=name)
                    elif isinstance(cache, tuple):
                        tracker.save_output_cache(cache_data=cache[0],
                                                  cache_meta=cache[1],
                                                  cache_name=name,
                                                  output_storage_engine=job_parameters.storage_engine,
                                                  output_storage_address=job_parameters.engines_address.get(EngineType.STORAGE, {}),
                                                  token={"username": user_name})
                    else:
                        raise RuntimeError(f"unsupported output cache type: {type(cache)}")
            self.report_info["party_status"] = TaskStatus.SUCCESS if need_run else TaskStatus.PASS
        except PassError as e:
            self.report_info["party_status"] = TaskStatus.PASS
        except Exception as e:
            traceback.print_exc()
            self.report_info["party_status"] = TaskStatus.FAILED
            self.generate_error_report()
            LOGGER.exception(e)
            try:
                LOGGER.info("start destroy sessions")
                sess.destroy_all_sessions()
                LOGGER.info("destroy all sessions success")
            except Exception as e:
                LOGGER.exception(e)
        finally:
            try:
                self.report_info["end_time"] = current_timestamp()
                self.report_info["elapsed"] = self.report_info["end_time"] - start_time
                self.report_task_info_to_driver()
            except Exception as e:
                self.report_info["party_status"] = TaskStatus.FAILED
                traceback.print_exc()
                LOGGER.exception(e)
        msg = f"finish {args.component_name} {args.task_id} {args.task_version} on {args.role} {args.party_id} with {self.report_info['party_status']}"
        LOGGER.info(msg)
        print(msg)
        return self.report_info

    @classmethod
    def log_output_data_table_tracker(cls, job_id, input_table_list, output_table_list):
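        """Record data lineage: link each output table to the input tables it was derived from."""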
        try:
            parent_number = 0
            if len(input_table_list) > 1 and len(output_table_list) > 1:
                # TODO
                return
            for input_table in input_table_list:
                for output_table in output_table_list:
                    DataTableTracker.create_table_tracker(output_table.get("name"), output_table.get("namespace"),
                                                          entity_info={
                                                              "have_parent": True,
                                                              "parent_table_namespace": input_table.get("namespace"),
                                                              "parent_table_name": input_table.get("name"),
                                                              "parent_number": parent_number,
                                                              "job_id": job_id
                                                          })
                parent_number += 1
        except Exception as e:
            LOGGER.exception(e)

    @classmethod
    def get_job_args_on_party(cls, dsl_parser, job_runtime_conf, role, party_id):
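        """Return the 'args' input (e.g. data tables) configured for this role and party in the job DSL."""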
        party_index = job_runtime_conf["role"][role].index(int(party_id))
        job_args = dsl_parser.get_args_input()
        job_args_on_party = job_args[role][party_index].get('args') if role in job_args else {}
        return job_args_on_party

    @classmethod
    def get_task_run_args(cls, job_id, role, party_id, task_id, task_version,
                          job_args, job_parameters: RunParameters, task_parameters: RunParameters,
                          input_dsl, filter_type=None, filter_attr=None, get_input_table=False):
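        """Resolve the component's input DSL into runtime arguments.

        Loads upstream output data tables, caches and models through the tracker
        and returns (task_run_args, input_table_info_list); if get_input_table is
        True, only the input table name/namespace mapping is returned.
        """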
        task_run_args = {}
        input_table = {}
        input_table_info_list = []
        if 'idmapping' in role:
            return {}
        for input_type, input_detail in input_dsl.items():
            if filter_type and input_type not in filter_type:
                continue
            if input_type == 'data':
                this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
                for data_type, data_list in input_detail.items():
                    data_dict = {}
                    for data_key in data_list:
                        data_key_item = data_key.split('.')
                        data_dict[data_key_item[0]] = {data_type: []}
                    for data_key in data_list:
                        data_key_item = data_key.split('.')
                        search_component_name, search_data_name = data_key_item[0], data_key_item[1]
                        storage_table_meta = None
                        tracker_client = TrackerClient(job_id=job_id, role=role, party_id=party_id,
                                                       component_name=search_component_name,
                                                       task_id=task_id, task_version=task_version)
                        if search_component_name == 'args':
                            if job_args.get('data', {}).get(search_data_name).get('namespace', '') and job_args.get(
                                    'data', {}).get(search_data_name).get('name', ''):
                                storage_table_meta = storage.StorageTableMeta(
                                    name=job_args['data'][search_data_name]['name'],
                                    namespace=job_args['data'][search_data_name]['namespace'])
                        else:
                            upstream_output_table_infos_json = tracker_client.get_output_data_info(
                                data_name=search_data_name)
                            if upstream_output_table_infos_json:
                                tracker = Tracker(job_id=job_id, role=role, party_id=party_id,
                                                  component_name=search_component_name,
                                                  task_id=task_id, task_version=task_version)
                                upstream_output_table_infos = []
                                for _ in upstream_output_table_infos_json:
                                    upstream_output_table_infos.append(fill_db_model_object(
                                        Tracker.get_dynamic_db_model(TrackingOutputDataInfo, job_id)(), _))
                                output_tables_meta = tracker.get_output_data_table(upstream_output_table_infos)
                                if output_tables_meta:
                                    storage_table_meta = output_tables_meta.get(search_data_name, None)
                        args_from_component = this_type_args[search_component_name] = this_type_args.get(
                            search_component_name, {})
                        if get_input_table and storage_table_meta:
                            input_table[data_key] = {'namespace': storage_table_meta.get_namespace(),
                                                     'name': storage_table_meta.get_name()}
                            computing_table = None
                        elif storage_table_meta:
                            LOGGER.info(f"loading computing table with {task_parameters.computing_partitions} partitions")
                            computing_table = session.get_computing_session().load(
                                storage_table_meta.get_address(),
                                schema=storage_table_meta.get_schema(),
                                partitions=task_parameters.computing_partitions)
                            input_table_info_list.append({'namespace': storage_table_meta.get_namespace(),
                                                          'name': storage_table_meta.get_name()})
                        else:
                            computing_table = None
                        if not computing_table or not filter_attr or not filter_attr.get("data", None):
                            data_dict[search_component_name][data_type].append(computing_table)
                            args_from_component[data_type] = data_dict[search_component_name][data_type]
                        else:
                            args_from_component[data_type] = dict(
                                [(a, getattr(computing_table, "get_{}".format(a))()) for a in filter_attr["data"]])
            elif input_type == "cache":
                this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
                for search_key in input_detail:
                    search_component_name, cache_name = search_key.split(".")
                    tracker = Tracker(job_id=job_id, role=role, party_id=party_id, component_name=search_component_name)
                    this_type_args[search_component_name] = tracker.get_output_cache(cache_name=cache_name)
            elif input_type in {'model', 'isometric_model'}:
                this_type_args = task_run_args[input_type] = task_run_args.get(input_type, {})
                for dsl_model_key in input_detail:
                    dsl_model_key_items = dsl_model_key.split('.')
                    if len(dsl_model_key_items) == 2:
                        search_component_name, search_model_alias = dsl_model_key_items[0], dsl_model_key_items[1]
                    elif len(dsl_model_key_items) == 3 and dsl_model_key_items[0] == 'pipeline':
                        search_component_name, search_model_alias = dsl_model_key_items[1], dsl_model_key_items[2]
                    else:
                        raise Exception('get input {} failed'.format(input_type))
                    kwargs = {
                        'job_id': job_id,
                        'role': role,
                        'party_id': party_id,
                        'component_name': search_component_name,
                        'model_id': job_parameters.model_id,
                        # in the prediction job, job_parameters.model_version comes from the training job
                        'model_version': job_parameters.model_version,
                    }
                    # get models from the training job
                    models = TrackerClient(**kwargs).read_component_output_model(search_model_alias)
                    if not models and job_parameters.job_type == 'predict':
                        kwargs['model_version'] = job_id
                        # get models from the prediction job if not found in the training job
                        models = TrackerClient(**kwargs).read_component_output_model(search_model_alias)
                    this_type_args[search_component_name] = models
            else:
                raise Exception(f"unsupported input type: {input_type}")
        if get_input_table:
            return input_table
        return task_run_args, input_table_info_list

    @classmethod
    def monkey_patch(cls):
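        """Import and apply every patch package under fate_flow/monkey_patch.

        Each sub-package is expected to expose a ``monkey_patch`` module with a
        ``patch_all()`` function, i.e. ``fate_flow/monkey_patch/<pkg>/monkey_patch.py``.
        """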
        package_name = "monkey_patch"
        package_path = os.path.join(get_fate_flow_python_directory(), "fate_flow", package_name)
        if not os.path.exists(package_path):
            return
        for f in os.listdir(package_path):
            f_path = os.path.join(get_fate_flow_python_directory(), "fate_flow", package_name, f)
            if not os.path.isdir(f_path) or "__pycache__" in f_path:
                continue
            patch_module = importlib.import_module("fate_flow." + package_name + '.' + f + '.monkey_patch')
            patch_module.patch_all()

    def generate_error_report(self):
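        """Build report_info["error_report"] from the current exception traceback,
        masking PYTHONPATH prefixes (unless ERROR_REPORT_WITH_PATH is set) and IP addresses."""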
        if ERROR_REPORT:
            _error = ""
            etype, value, tb = sys.exc_info()
            path_list = os.getenv("PYTHONPATH").split(":")
            for line in traceback.TracebackException(type(value), value, tb).format(chain=True):
                if not ERROR_REPORT_WITH_PATH:
                    for path in path_list:
                        line = line.replace(path, "xxx")
                line = replace_ip(line)
                _error += line
            self.report_info["error_report"] = _error.rstrip("\n")


# this file may not be running on the same machine as fate_flow,
# so we need to use the tracker to get the input and save the output
if __name__ == '__main__':
    worker = TaskExecutor()
    worker.run()
    worker.report_task_info_to_driver()