# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import errno
import os
import random
import sys
import threading
import typing
from functools import wraps

from fate_arch.common import FederatedMode, file_utils
from fate_arch.common.base_utils import current_timestamp, fate_uuid, json_dumps
from fate_flow.db.db_models import DB, Job, Task
from fate_flow.db.db_utils import query_db
from fate_flow.db.job_default_config import JobDefaultConfig
from fate_flow.db.service_registry import ServerRegistry
from fate_flow.entity import JobConfiguration, RunParameters
from fate_flow.entity.run_status import JobStatus, TaskStatus
from fate_flow.entity.types import InputSearchType
from fate_flow.settings import FATE_BOARD_DASHBOARD_ENDPOINT
from fate_flow.utils import data_utils, detect_utils, process_utils, session_utils
from fate_flow.utils.base_utils import get_fate_flow_directory
from fate_flow.utils.log_utils import schedule_logger
from fate_flow.utils.schedule_utils import get_dsl_parser_by_version

PIPELINE_COMPONENT_NAME = 'pipeline'
PIPELINE_MODEL_ALIAS = 'pipeline'
PIPELINE_COMPONENT_MODULE_NAME = 'Pipeline'
PIPELINE_MODEL_NAME = 'Pipeline'


class JobIdGenerator(object):
    _lock = threading.RLock()

    def __init__(self, initial_value=0):
        self._value = initial_value
        self._pre_timestamp = None
        self._max = 99999

    def next_id(self):
        """
        generate next job id with locking
        """
        # TODO: there is duplication in the case of multiple instances deployment
        now = datetime.datetime.now()
        with JobIdGenerator._lock:
            if self._pre_timestamp == now:
                if self._value < self._max:
                    self._value += 1
                else:
                    now += datetime.timedelta(microseconds=1)
                    self._pre_timestamp = now
                    self._value = 0
            else:
                self._pre_timestamp = now
                self._value = 0
            return "{}{}".format(now.strftime("%Y%m%d%H%M%S%f"), self._value)


job_id_generator = JobIdGenerator()


def generate_job_id():
    return job_id_generator.next_id()


def generate_task_id(job_id, component_name):
    return '{}_{}'.format(job_id, component_name)


def generate_task_version_id(task_id, task_version):
    return "{}_{}".format(task_id, task_version)
- def generate_session_id(task_id, task_version, role, party_id, suffix=None, random_end=False):
- items = [task_id, str(task_version), role, str(party_id)]
- if suffix:
- items.append(suffix)
- if random_end:
- items.append(fate_uuid())
- return "_".join(items)


def generate_task_input_data_namespace(task_id, task_version, role, party_id):
    return "input_data_{}".format(generate_session_id(task_id=task_id,
                                                      task_version=task_version,
                                                      role=role,
                                                      party_id=party_id))


def get_job_directory(job_id, *args):
    return os.path.join(get_fate_flow_directory(), 'jobs', job_id, *args)


def get_job_log_directory(job_id, *args):
    return os.path.join(get_fate_flow_directory(), 'logs', job_id, *args)


def get_task_directory(job_id, role, party_id, component_name, task_id, task_version, **kwargs):
    return get_job_directory(job_id, role, party_id, component_name, task_id, task_version)


def get_general_worker_directory(worker_name, worker_id, *args):
    return os.path.join(get_fate_flow_directory(), worker_name, worker_id, *args)


def get_general_worker_log_directory(worker_name, worker_id, *args):
    return os.path.join(get_fate_flow_directory(), 'logs', worker_name, worker_id, *args)


def check_config(config: typing.Dict, required_parameters: typing.List):
    for parameter in required_parameters:
        if parameter not in config:
            return False, 'configuration has no {} parameter'.format(parameter)
    return True, 'ok'
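
# Illustrative usage (added commentary):
#
#   check_config({"initiator": {}}, ["initiator", "role"])
#   # -> (False, 'configuration has no role parameter')
#   check_config({"initiator": {}, "role": {}}, ["initiator", "role"])
#   # -> (True, 'ok')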


def check_job_conf(runtime_conf, job_dsl):
    detect_utils.check_config(runtime_conf, ['initiator', 'role'])
    detect_utils.check_config(runtime_conf['initiator'], ['role', 'party_id'])
    # normalize all party ids to integers
    runtime_conf['initiator']['party_id'] = int(runtime_conf['initiator']['party_id'])
    for r in runtime_conf['role'].keys():
        for i in range(len(runtime_conf['role'][r])):
            runtime_conf['role'][r][i] = int(runtime_conf['role'][r][i])
    constraint_check(runtime_conf, job_dsl)


def runtime_conf_basic(if_local=False):
    job_runtime_conf = {
        "dsl_version": 2,
        "initiator": {},
        "job_parameters": {
            "common": {
                "federated_mode": FederatedMode.SINGLE
            },
        },
        "role": {},
        "component_parameters": {}
    }
    if if_local:
        job_runtime_conf["initiator"]["role"] = "local"
        job_runtime_conf["initiator"]["party_id"] = 0
        job_runtime_conf["role"]["local"] = [0]
    return job_runtime_conf
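
# Illustrative note (added commentary): runtime_conf_basic(if_local=True)
# yields a minimal single-party conf along these lines:
#
#   {"dsl_version": 2,
#    "initiator": {"role": "local", "party_id": 0},
#    "job_parameters": {"common": {"federated_mode": FederatedMode.SINGLE}},
#    "role": {"local": [0]},
#    "component_parameters": {}}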


def new_runtime_conf(job_dir, method, module, role, party_id):
    if role:
        conf_path_dir = os.path.join(job_dir, method, module, role, str(party_id))
    else:
        conf_path_dir = os.path.join(job_dir, method, module, str(party_id))
    os.makedirs(conf_path_dir, exist_ok=True)
    return os.path.join(conf_path_dir, 'runtime_conf.json')


def save_job_conf(job_id, role, party_id, dsl, runtime_conf, runtime_conf_on_party, train_runtime_conf, pipeline_dsl=None):
    path_dict = get_job_conf_path(job_id=job_id, role=role, party_id=party_id)
    dump_job_conf(path_dict=path_dict,
                  dsl=dsl,
                  runtime_conf=runtime_conf,
                  runtime_conf_on_party=runtime_conf_on_party,
                  train_runtime_conf=train_runtime_conf,
                  pipeline_dsl=pipeline_dsl)
    return path_dict


def save_task_using_job_conf(task: Task):
    task_dir = get_task_directory(job_id=task.f_job_id,
                                  role=task.f_role,
                                  party_id=task.f_party_id,
                                  component_name=task.f_component_name,
                                  task_id=task.f_task_id,
                                  task_version=str(task.f_task_version))
    return save_using_job_conf(task.f_job_id, task.f_role, task.f_party_id, config_dir=task_dir)


def save_using_job_conf(job_id, role, party_id, config_dir):
    path_dict = get_job_conf_path(job_id=job_id, role=role, party_id=party_id, specified_dir=config_dir)
    job_configuration = get_job_configuration(job_id=job_id,
                                              role=role,
                                              party_id=party_id)
    dump_job_conf(path_dict=path_dict,
                  dsl=job_configuration.dsl,
                  runtime_conf=job_configuration.runtime_conf,
                  runtime_conf_on_party=job_configuration.runtime_conf_on_party,
                  train_runtime_conf=job_configuration.train_runtime_conf,
                  pipeline_dsl=None)
    return path_dict


def dump_job_conf(path_dict, dsl, runtime_conf, runtime_conf_on_party, train_runtime_conf, pipeline_dsl=None):
    os.makedirs(os.path.dirname(path_dict.get('dsl_path')), exist_ok=True)
    os.makedirs(os.path.dirname(path_dict.get('runtime_conf_on_party_path')), exist_ok=True)
    for data, conf_path in [(dsl, path_dict['dsl_path']),
                            (runtime_conf, path_dict['runtime_conf_path']),
                            (runtime_conf_on_party, path_dict['runtime_conf_on_party_path']),
                            (train_runtime_conf, path_dict['train_runtime_conf_path']),
                            (pipeline_dsl, path_dict['pipeline_dsl_path'])]:
        with open(conf_path, 'w+') as f:
            f.truncate()
            if not data:
                data = {}
            f.write(json_dumps(data, indent=4))
            f.flush()
    return path_dict


@DB.connection_context()
def get_job_configuration(job_id, role, party_id) -> JobConfiguration:
    jobs = Job.select(Job.f_dsl, Job.f_runtime_conf, Job.f_train_runtime_conf, Job.f_runtime_conf_on_party).where(Job.f_job_id == job_id,
                                                                                                                  Job.f_role == role,
                                                                                                                  Job.f_party_id == party_id)
    if jobs:
        job = jobs[0]
        return JobConfiguration(**job.to_human_model_dict())


def get_task_using_job_conf(task_info: dict):
    task_dir = get_task_directory(**task_info)
    return read_job_conf(task_info["job_id"], task_info["role"], task_info["party_id"], task_dir)


def read_job_conf(job_id, role, party_id, specified_dir=None):
    path_dict = get_job_conf_path(job_id=job_id, role=role, party_id=party_id, specified_dir=specified_dir)
    conf_dict = {}
    for key, path in path_dict.items():
        config = file_utils.load_json_conf(path)
        # strip the "_path" suffix to recover the JobConfiguration field name;
        # str.rstrip strips a character set rather than a suffix, so slice instead
        conf_dict[key[:-len("_path")]] = config
    return JobConfiguration(**conf_dict)


def get_job_conf_path(job_id, role, party_id, specified_dir=None):
    conf_dir = get_job_directory(job_id) if not specified_dir else specified_dir
    job_dsl_path = os.path.join(conf_dir, 'job_dsl.json')
    job_runtime_conf_path = os.path.join(conf_dir, 'job_runtime_conf.json')
    if not specified_dir:
        job_runtime_conf_on_party_path = os.path.join(conf_dir, role, str(party_id), 'job_runtime_on_party_conf.json')
    else:
        job_runtime_conf_on_party_path = os.path.join(conf_dir, 'job_runtime_on_party_conf.json')
    train_runtime_conf_path = os.path.join(conf_dir, 'train_runtime_conf.json')
    pipeline_dsl_path = os.path.join(conf_dir, 'pipeline_dsl.json')
    return {'dsl_path': job_dsl_path,
            'runtime_conf_path': job_runtime_conf_path,
            'runtime_conf_on_party_path': job_runtime_conf_on_party_path,
            'train_runtime_conf_path': train_runtime_conf_path,
            'pipeline_dsl_path': pipeline_dsl_path}
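
# Illustrative note (added commentary): with no specified_dir the paths live
# under the job directory, e.g. (hypothetical root):
#
#   {"dsl_path": ".../jobs/<job_id>/job_dsl.json",
#    "runtime_conf_path": ".../jobs/<job_id>/job_runtime_conf.json",
#    "runtime_conf_on_party_path": ".../jobs/<job_id>/<role>/<party_id>/job_runtime_on_party_conf.json",
#    "train_runtime_conf_path": ".../jobs/<job_id>/train_runtime_conf.json",
#    "pipeline_dsl_path": ".../jobs/<job_id>/pipeline_dsl.json"}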


@DB.connection_context()
def get_upload_job_configuration_summary(upload_tasks: typing.List[Task]):
    jobs_run_conf = {}
    for task in upload_tasks:
        jobs = Job.select(Job.f_job_id, Job.f_runtime_conf_on_party, Job.f_description).where(Job.f_job_id == task.f_job_id)
        job = jobs[0]
        jobs_run_conf[job.f_job_id] = job.f_runtime_conf_on_party["component_parameters"]["role"]["local"]["0"]["upload_0"]
        jobs_run_conf[job.f_job_id]["notes"] = job.f_description
    return jobs_run_conf


@DB.connection_context()
def get_job_parameters(job_id, role, party_id):
    jobs = Job.select(Job.f_runtime_conf_on_party).where(Job.f_job_id == job_id,
                                                         Job.f_role == role,
                                                         Job.f_party_id == party_id)
    if jobs:
        job = jobs[0]
        return job.f_runtime_conf_on_party.get("job_parameters")
    else:
        return {}


@DB.connection_context()
def get_job_dsl(job_id, role, party_id):
    jobs = Job.select(Job.f_dsl).where(Job.f_job_id == job_id,
                                       Job.f_role == role,
                                       Job.f_party_id == party_id)
    if jobs:
        job = jobs[0]
        return job.f_dsl
    else:
        return {}


@DB.connection_context()
def list_job(limit=0, offset=0, query=None, order_by=None):
    return query_db(Job, limit, offset, query, order_by)


@DB.connection_context()
def list_task(limit=0, offset=0, query=None, order_by=None):
    return query_db(Task, limit, offset, query, order_by)


def check_job_process(pid):
    if pid < 0:
        return False
    if pid == 0:
        raise ValueError('invalid PID 0')
    try:
        os.kill(pid, 0)
    except OSError as err:
        if err.errno == errno.ESRCH:
            # ESRCH == No such process
            return False
        elif err.errno == errno.EPERM:
            # EPERM clearly means there's a process to deny access to
            return True
        else:
            # According to "man 2 kill" possible error values are
            # (EINVAL, EPERM, ESRCH)
            raise
    else:
        return True
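
# Illustrative usage (added commentary): check_job_process(os.getpid()) returns
# True for a live process; once the process exits, the os.kill(pid, 0) probe
# raises ESRCH and the check returns False. EPERM still means the pid exists,
# just owned by another user.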


def check_job_is_timeout(job: Job):
    job_parameters = job.f_runtime_conf_on_party["job_parameters"]
    timeout = job_parameters.get("timeout", JobDefaultConfig.job_timeout)
    now_time = current_timestamp()
    running_time = (now_time - job.f_create_time) / 1000
    if running_time > timeout:
        schedule_logger(job.f_job_id).info(f'job running time {running_time}s exceeds timeout {timeout}s')
        return True
    else:
        return False


def start_session_stop(task):
    job_parameters = RunParameters(**get_job_parameters(job_id=task.f_job_id, role=task.f_role, party_id=task.f_party_id))
    session_manager_id = generate_session_id(task.f_task_id, task.f_task_version, task.f_role, task.f_party_id)
    if task.f_status != TaskStatus.WAITING:
        schedule_logger(task.f_job_id).info(f'start run subprocess to stop task sessions {session_manager_id}')
    else:
        schedule_logger(task.f_job_id).info(f'task is waiting, skip stopping sessions {session_manager_id}')
        return
    # party id may be stored as an int; os.path.join requires str components
    task_dir = os.path.join(get_job_directory(job_id=task.f_job_id), task.f_role,
                            str(task.f_party_id), task.f_component_name, 'session_stop')
    os.makedirs(task_dir, exist_ok=True)
    process_cmd = [
        sys.executable or 'python3',
        sys.modules[session_utils.SessionStop.__module__].__file__,
        '--session', session_manager_id,
        '--computing', job_parameters.computing_engine,
        '--federation', job_parameters.federation_engine,
        '--storage', job_parameters.storage_engine,
        '-c', 'stop' if task.f_status == JobStatus.SUCCESS else 'kill'
    ]
    p = process_utils.run_subprocess(job_id=task.f_job_id, config_dir=task_dir, process_cmd=process_cmd)
    p.wait()
    p.poll()


def get_timeout(job_id, timeout, runtime_conf, dsl):
    try:
        if timeout > 0:
            schedule_logger(job_id).info(f'setting job timeout {timeout}')
            return timeout
        else:
            default_timeout = job_default_timeout(runtime_conf, dsl)
            schedule_logger(job_id).info(f'job timeout {timeout} is not a positive number, using default timeout {default_timeout}')
            return default_timeout
    except Exception:
        default_timeout = job_default_timeout(runtime_conf, dsl)
        schedule_logger(job_id).info(f'job timeout {timeout} is invalid, using default timeout {default_timeout}')
        return default_timeout


def job_default_timeout(runtime_conf, dsl):
    # future versions will improve
    timeout = JobDefaultConfig.job_timeout
    return timeout


def get_board_url(job_id, role, party_id):
    board_url = "http://{}:{}{}".format(
        ServerRegistry.FATEBOARD.get("host"),
        ServerRegistry.FATEBOARD.get("port"),
        FATE_BOARD_DASHBOARD_ENDPOINT).format(job_id, role, party_id)
    return board_url


def check_job_inheritance_parameters(job, inheritance_jobs, inheritance_tasks):
    if not inheritance_jobs:
        raise Exception(
            f"job {job.f_inheritance_info.get('job_id')} not found for role {job.f_role} party id {job.f_party_id}")
    inheritance_job = inheritance_jobs[0]
    task_status = {}
    for task in inheritance_tasks:
        task_status[task.f_component_name] = task.f_status
    for component in job.f_inheritance_info.get('component_list'):
        if component not in task_status:
            raise Exception(f"job {job.f_inheritance_info.get('job_id')} has no component {component}")
        elif task_status[component] not in [TaskStatus.SUCCESS, TaskStatus.PASS]:
            raise Exception(f"job {job.f_inheritance_info.get('job_id')} component {component} status: {task_status[component]}")
    dsl_parser = get_dsl_parser_by_version()
    dsl_parser.verify_conf_reusability(inheritance_job.f_runtime_conf, job.f_runtime_conf, job.f_inheritance_info.get('component_list'))
    dsl_parser.verify_dsl_reusability(inheritance_job.f_dsl, job.f_dsl, job.f_inheritance_info.get('component_list', []))


def get_job_all_components(dsl):
    return [dsl['components'][component_name]['module'].lower() for component_name in dsl['components'].keys()]


def constraint_check(job_runtime_conf, job_dsl):
    if job_dsl:
        all_components = get_job_all_components(job_dsl)
        glm = ['heterolr', 'heterolinr', 'heteropoisson']
        for cpn in glm:
            if cpn in all_components:
                roles = job_runtime_conf.get('role')
                if 'guest' in roles.keys() and 'arbiter' in roles.keys() and 'host' in roles.keys():
                    for party_id in set(roles['guest']) & set(roles['arbiter']):
                        if party_id not in roles['host'] or len(set(roles['guest']) & set(roles['arbiter'])) != len(roles['host']):
                            raise Exception("{} component has a party id constraint, please check the role config: {}".format(cpn, job_runtime_conf.get('role')))


def get_job_dataset(is_initiator, role, party_id, roles, job_args):
    dataset = {}
    dsl_version = 2 if job_args.get('dsl_version') == 2 else 1
    for _role, _role_party_args in job_args.items():
        if _role == "dsl_version":
            continue
        if is_initiator or _role == role:
            for _party_index in range(len(_role_party_args)):
                _party_id = roles[_role][_party_index]
                if is_initiator or _party_id == party_id:
                    dataset[_role] = dataset.get(_role, {})
                    dataset[_role][_party_id] = dataset[_role].get(_party_id, {})
                    if dsl_version == 1:
                        for _data_type, _data_location in _role_party_args[_party_index]['args']['data'].items():
                            dataset[_role][_party_id][_data_type] = '{}.{}'.format(
                                _data_location['namespace'], _data_location['name'])
                    else:
                        for key in _role_party_args[_party_index].keys():
                            for _data_type, _data_location in _role_party_args[_party_index][key].items():
                                search_type = data_utils.get_input_search_type(parameters=_data_location)
                                if search_type is InputSearchType.TABLE_INFO:
                                    dataset[_role][_party_id][key] = '{}.{}'.format(_data_location['namespace'], _data_location['name'])
                                elif search_type is InputSearchType.JOB_COMPONENT_OUTPUT:
                                    dataset[_role][_party_id][key] = '{}.{}.{}'.format(_data_location['job_id'], _data_location['component_name'], _data_location['data_name'])
                                else:
                                    dataset[_role][_party_id][key] = "unknown"
    return dataset


def asynchronous_function(func):
    @wraps(func)
    def _wrapper(*args, **kwargs):
        is_asynchronous = kwargs.pop("is_asynchronous", False)
        if is_asynchronous:
            thread = threading.Thread(target=func, args=args, kwargs=kwargs)
            thread.start()
            return is_asynchronous
        else:
            return func(*args, **kwargs)
    return _wrapper
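
# Illustrative usage (added commentary; apply_func is hypothetical):
#
#   @asynchronous_function
#   def apply_func(job_id):
#       ...  # long-running work
#
#   apply_func("20240101000000000000", is_asynchronous=True)  # returns True at
#                                                             # once, work runs
#                                                             # in a thread
#   apply_func("20240101000000000000")                        # runs synchronously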


def task_report(tasks):
    now_time = current_timestamp()
    report_list = [{"component_name": task.f_component_name, "start_time": task.f_start_time,
                    "end_time": task.f_end_time, "elapsed": task.f_elapsed, "status": task.f_status}
                   for task in tasks]
    report_list.sort(key=lambda x: (x["start_time"] if x["start_time"] else now_time, x["status"]))
    return report_list
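
# Illustrative note (added commentary, values hypothetical): each report entry
# looks like
#
#   {"component_name": "upload_0", "start_time": 1700000000000,
#    "end_time": 1700000010000, "elapsed": 10000, "status": TaskStatus.SUCCESS}
#
# where the timestamps are milliseconds; tasks that never started sort last
# because the current timestamp is substituted for their missing start time.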


def get_component_parameters(job_providers, dsl_parser, provider_detail, role, party_id):
    component_parameters = dict()
    for component in job_providers.keys():
        provider_info = job_providers[component]["provider"]
        provider_name = provider_info["name"]
        provider_version = provider_info["version"]
        parameter = dsl_parser.parse_component_parameters(component,
                                                          provider_detail,
                                                          provider_name,
                                                          provider_version,
                                                          local_role=role,
                                                          local_party_id=party_id)
        module_name = dsl_parser.get_component_info(component_name=component).get_module().lower()
        if module_name not in component_parameters.keys():
            component_parameters[module_name] = [parameter.get("ComponentParam", {})]
        else:
            component_parameters[module_name].append(parameter.get("ComponentParam", {}))
    return component_parameters


def generate_retry_interval(cur_retry, max_retry_cnt, long_retry_cnt):
    if cur_retry < max_retry_cnt - long_retry_cnt:
        retry_interval = random.random() * 10 + 5
    else:
        retry_interval = round(300 + random.random() * 10, 3)
    return retry_interval
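
# Illustrative note (added commentary): the first max_retry_cnt - long_retry_cnt
# retries back off for a short random 5-15s, the remaining "long" retries for
# roughly 300-310s; e.g. with max_retry_cnt=10 and long_retry_cnt=2, retries
# 0-7 wait a few seconds and retries 8-9 wait about five minutes.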