123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import os
- import subprocess
- import sys
- from uuid import uuid1
- from fate_arch.common.base_utils import current_timestamp, json_dumps
- from fate_arch.common.file_utils import load_json_conf
- from fate_arch.metastore.base_model import auto_date_timestamp_db_field
- from fate_flow.db.db_models import DB, Task, WorkerInfo
- from fate_flow.db.runtime_config import RuntimeConfig
- from fate_flow.entity import ComponentProvider, RunParameters
- from fate_flow.entity.types import WorkerName
- from fate_flow.settings import stat_logger
- from fate_flow.utils import job_utils, process_utils
- from fate_flow.utils.log_utils import failed_log, ready_log, schedule_logger, start_log, successful_log
class WorkerManager:
    """Spawn, track, and kill fate_flow worker processes.

    Two families of workers are managed:
      * "general" workers (provider registration, dependence upload, task
        initialization) started via :meth:`start_general_worker`;
      * per-task executor workers started via :meth:`start_task_worker`,
        whose pid/cmd are persisted in the ``WorkerInfo`` table so they can
        later be killed with :meth:`kill_task_all_workers`.
    """

    @classmethod
    def start_general_worker(cls, worker_name: WorkerName, job_id="", role="", party_id=0, provider: ComponentProvider = None,
                             initialized_config: dict = None, run_in_subprocess=True, **kwargs):
        """Run a general (non-task-executor) worker.

        :param worker_name: which worker to run (PROVIDER_REGISTRAR,
            DEPENDENCE_UPLOAD or TASK_INITIALIZER).
        :param job_id:/:param role:/:param party_id: optional job context used
            for directory layout, logging and the worker CLI.
        :param provider: required for PROVIDER_REGISTRAR / DEPENDENCE_UPLOAD.
        :param initialized_config: required for TASK_INITIALIZER; must contain
            a "provider" entry.
        :param run_in_subprocess: when False (and not in DEBUG mode) the worker
            module is invoked in-process instead of via a subprocess.
        :return: ``(returncode, result)`` tuple for synchronous runs; ``None``
            for the asynchronous DEPENDENCE_UPLOAD path.
        :raises ValueError: when the required provider/initialized_config
            argument is missing.
        :raises Exception: on unsupported worker name, subprocess failure or
            timeout (120s).
        """
        if RuntimeConfig.DEBUG:
            # In debug deployments always isolate the worker in a subprocess.
            run_in_subprocess = True
        # Snapshot of the call arguments; used below to feed callback_param.
        participate = locals()
        worker_id, config_dir, log_dir = cls.get_process_dirs(worker_name=worker_name,
                                                              job_id=job_id,
                                                              role=role,
                                                              party_id=party_id)
        if worker_name in [WorkerName.PROVIDER_REGISTRAR, WorkerName.DEPENDENCE_UPLOAD]:
            if not provider:
                raise ValueError("no provider argument")
            config = {
                "provider": provider.to_dict()
            }
            if worker_name == WorkerName.PROVIDER_REGISTRAR:
                from fate_flow.worker.provider_registrar import ProviderRegistrar
                module = ProviderRegistrar
                module_file_path = sys.modules[ProviderRegistrar.__module__].__file__
                specific_cmd = []
            elif worker_name == WorkerName.DEPENDENCE_UPLOAD:
                from fate_flow.worker.dependence_upload import DependenceUpload
                module = DependenceUpload
                module_file_path = sys.modules[DependenceUpload.__module__].__file__
                specific_cmd = [
                    '--dependence_type', kwargs.get("dependence_type")
                ]
            # provider_info must be set for BOTH branches above: it feeds
            # cls.get_env() below (fixes an unbound name on the
            # PROVIDER_REGISTRAR path).
            provider_info = provider.to_dict()
        elif worker_name is WorkerName.TASK_INITIALIZER:
            if not initialized_config:
                raise ValueError("no initialized_config argument")
            config = initialized_config
            from fate_flow.worker.task_initializer import TaskInitializer
            module = TaskInitializer
            module_file_path = sys.modules[TaskInitializer.__module__].__file__
            specific_cmd = []
            provider_info = initialized_config["provider"]
        else:
            raise Exception(f"not support {worker_name} worker")
        config_path, result_path = cls.get_config(config_dir=config_dir, config=config, log_dir=log_dir)

        process_cmd = [
            sys.executable or "python3",
            module_file_path,
            "--config", config_path,
            '--result', result_path,
            "--log_dir", log_dir,
            "--parent_log_dir", os.path.dirname(log_dir),
            "--worker_id", worker_id,
            "--run_ip", RuntimeConfig.JOB_SERVER_HOST,
            "--job_server", f"{RuntimeConfig.JOB_SERVER_HOST}:{RuntimeConfig.HTTP_PORT}",
        ]

        if job_id:
            process_cmd.extend([
                "--job_id", job_id,
                "--role", role,
                "--party_id", party_id,
            ])

        process_cmd.extend(specific_cmd)

        if run_in_subprocess:
            p = process_utils.run_subprocess(job_id=job_id, config_dir=config_dir, process_cmd=process_cmd,
                                             added_env=cls.get_env(job_id, provider_info), log_dir=log_dir,
                                             cwd_dir=config_dir, process_name=worker_name.value, process_id=worker_id)
            participate["pid"] = p.pid
            # Job-scoped runs log to the job's scheduler log, otherwise to the
            # server-wide stat log; the message itself is identical.
            if job_id and role and party_id:
                logger = schedule_logger(job_id)
            else:
                logger = stat_logger
            msg = f"{worker_name} worker {worker_id} subprocess {p.pid}"
            logger.info(ready_log(msg=msg, role=role, party_id=party_id))

            if worker_name in [WorkerName.DEPENDENCE_UPLOAD]:
                # Asynchronous: do not wait for the subprocess; optionally fire
                # the caller-supplied callback with the requested arguments.
                if kwargs.get("callback") and kwargs.get("callback_param"):
                    callback_param = {}
                    participate.update(participate.get("kwargs", {}))
                    for k, v in participate.items():
                        if k in kwargs.get("callback_param"):
                            callback_param[k] = v
                    kwargs.get("callback")(**callback_param)
            else:
                # Synchronous: wait (bounded) and surface the worker's result.
                try:
                    p.wait(timeout=120)
                    if p.returncode == 0:
                        logger.info(successful_log(msg=msg, role=role, party_id=party_id))
                        # Worker writes its output to result_path as JSON.
                        return p.returncode, load_json_conf(result_path)
                    logger.info(failed_log(msg=msg, role=role, party_id=party_id))
                    std_path = process_utils.get_std_path(log_dir=log_dir, process_name=worker_name.value, process_id=worker_id)
                    raise Exception(f"run error, please check logs: {std_path}, {log_dir}/INFO.log")
                except subprocess.TimeoutExpired:
                    err = failed_log(msg=f"{msg} run timeout", role=role, party_id=party_id)
                    logger.exception(err)
                    raise Exception(err)
                finally:
                    # Best-effort cleanup; kill() on an exited process is
                    # harmless, poll() reaps the child.
                    try:
                        p.kill()
                        p.poll()
                    except Exception as e:
                        logger.exception(e)
        else:
            # In-process execution: translate the CLI back to keyword args and
            # call the worker module directly.
            kwargs = cls.cmd_to_func_kwargs(process_cmd)
            code, message, result = module().run(**kwargs)
            if code == 0:
                return code, result
            raise Exception(message)

    @classmethod
    def start_task_worker(cls, worker_name, task: Task, task_parameters: RunParameters = None,
                          executable: list = None, extra_env: dict = None, **kwargs):
        """Launch a task executor subprocess for the given task.

        :param worker_name: must be WorkerName.TASK_EXECUTOR.
        :param task: the Task DB record the worker runs for.
        :param task_parameters: run parameters; loaded from the job config
            when not supplied.
        :param executable: optional interpreter/launcher argv prefix; defaults
            to the provider's PYTHON_ENV (or this interpreter).
        :param extra_env: extra environment variables merged into the worker env.
        :return: dict with ``run_pid``, ``worker_id`` and ``cmd``.
        :raises Exception: on unsupported worker name.
        """
        worker_id, config_dir, log_dir = cls.get_process_dirs(worker_name=worker_name,
                                                              job_id=task.f_job_id,
                                                              role=task.f_role,
                                                              party_id=task.f_party_id,
                                                              task=task)
        session_id = job_utils.generate_session_id(task.f_task_id, task.f_task_version, task.f_role, task.f_party_id)
        federation_session_id = job_utils.generate_task_version_id(task.f_task_id, task.f_task_version)
        info_kwargs = {}
        specific_cmd = []
        if worker_name is WorkerName.TASK_EXECUTOR:
            from fate_flow.worker.task_executor import TaskExecutor
            module_file_path = sys.modules[TaskExecutor.__module__].__file__
        else:
            raise Exception(f"not support {worker_name} worker")

        if task_parameters is None:
            task_parameters = RunParameters(**job_utils.get_job_parameters(task.f_job_id, task.f_role, task.f_party_id))
        config = task_parameters.to_dict()
        config["src_user"] = kwargs.get("src_user")
        config_path, result_path = cls.get_config(config_dir=config_dir, config=config, log_dir=log_dir)

        env = cls.get_env(task.f_job_id, task.f_provider_info)
        if executable:
            process_cmd = executable
        else:
            # Prefer the component provider's interpreter if one is declared.
            process_cmd = [env.get("PYTHON_ENV") or sys.executable or "python3"]

        common_cmd = [
            module_file_path,
            "--job_id", task.f_job_id,
            "--component_name", task.f_component_name,
            "--task_id", task.f_task_id,
            "--task_version", task.f_task_version,
            "--role", task.f_role,
            "--party_id", task.f_party_id,
            "--config", config_path,
            '--result', result_path,
            "--log_dir", log_dir,
            "--parent_log_dir", os.path.dirname(log_dir),
            "--worker_id", worker_id,
            "--run_ip", RuntimeConfig.JOB_SERVER_HOST,
            "--run_port", RuntimeConfig.HTTP_PORT,
            "--job_server", f"{RuntimeConfig.JOB_SERVER_HOST}:{RuntimeConfig.HTTP_PORT}",
            "--session_id", session_id,
            "--federation_session_id", federation_session_id
        ]
        process_cmd.extend(common_cmd)
        process_cmd.extend(specific_cmd)
        if extra_env:
            env.update(extra_env)
        schedule_logger(task.f_job_id).info(
            f"task {task.f_task_id} {task.f_task_version} on {task.f_role} {task.f_party_id} {worker_name} worker subprocess is ready")
        p = process_utils.run_subprocess(job_id=task.f_job_id, config_dir=config_dir, process_cmd=process_cmd,
                                         added_env=env, log_dir=log_dir, cwd_dir=config_dir, process_name=worker_name.value,
                                         process_id=worker_id)
        # Persist pid/cmd so kill_task_all_workers can find this process later.
        cls.save_worker_info(task=task, worker_name=worker_name, worker_id=worker_id, run_ip=RuntimeConfig.JOB_SERVER_HOST, run_pid=p.pid, config=config, cmd=process_cmd, **info_kwargs)
        return {"run_pid": p.pid, "worker_id": worker_id, "cmd": process_cmd}

    @classmethod
    def get_process_dirs(cls, worker_name: WorkerName, job_id=None, role=None, party_id=None, task: Task = None):
        """Allocate a fresh worker id and the config/log directories for it.

        Directory layout depends on context: task-scoped, job-scoped, or
        general (no job). The config dir is created; the log dir is left to
        the logging machinery.

        :return: ``(worker_id, config_dir, log_dir)``.
        """
        worker_id = uuid1().hex
        party_id = str(party_id)
        if task:
            config_dir = job_utils.get_job_directory(job_id, role, party_id, task.f_component_name, task.f_task_id,
                                                     str(task.f_task_version), worker_name.value, worker_id)
            log_dir = job_utils.get_job_log_directory(job_id, role, party_id, task.f_component_name)
        elif job_id and role and party_id:
            config_dir = job_utils.get_job_directory(job_id, role, party_id, worker_name.value, worker_id)
            log_dir = job_utils.get_job_log_directory(job_id, role, party_id, worker_name.value, worker_id)
        else:
            config_dir = job_utils.get_general_worker_directory(worker_name.value, worker_id)
            log_dir = job_utils.get_general_worker_log_directory(worker_name.value, worker_id)
        os.makedirs(config_dir, exist_ok=True)
        return worker_id, config_dir, log_dir

    @classmethod
    def get_config(cls, config_dir, config, log_dir):
        """Write *config* to ``<config_dir>/config.json`` and return
        ``(config_path, result_path)``; result.json is only a reserved path,
        the worker itself writes it."""
        config_path = os.path.join(config_dir, "config.json")
        with open(config_path, 'w') as fw:
            fw.write(json_dumps(config))
        result_path = os.path.join(config_dir, "result.json")
        return config_path, result_path

    @classmethod
    def get_env(cls, job_id, provider_info):
        """Build the subprocess environment from the component provider.

        :param provider_info: dict form of a ComponentProvider (must be
            accepted by ``ComponentProvider(**provider_info)``).
        """
        provider = ComponentProvider(**provider_info)
        env = provider.env.copy()
        # Make the provider's package importable inside the worker.
        env["PYTHONPATH"] = os.path.dirname(provider.path)
        if job_id:
            env["FATE_JOB_ID"] = job_id
        return env

    @classmethod
    def cmd_to_func_kwargs(cls, cmd):
        """Convert an argv built by this class back into keyword arguments.

        Assumes ``cmd`` is ``[interpreter, module_path, --key, value, ...]``
        — strictly alternating flag/value pairs from index 2 onward.
        """
        kwargs = {}
        for i in range(2, len(cmd), 2):
            # lstrip removes the leading dashes of the flag (char-set strip,
            # equivalent to the original lstrip("--")).
            kwargs[cmd[i].lstrip("-")] = cmd[i + 1]
        return kwargs

    @classmethod
    @DB.connection_context()
    def save_worker_info(cls, task: Task, worker_name: WorkerName, worker_id, **kwargs):
        """Insert a WorkerInfo row for a started worker.

        Copies matching non-None fields from the task record (excluding
        auto-managed timestamp fields), then applies the explicit ``f_``-prefixed
        overrides from **kwargs.

        :raises Exception: when the insert does not affect exactly one row.
        """
        worker = WorkerInfo()
        ignore_attr = auto_date_timestamp_db_field()
        for attr, value in task.to_dict().items():
            if hasattr(worker, attr) and attr not in ignore_attr and value is not None:
                setattr(worker, attr, value)
        worker.f_create_time = current_timestamp()
        worker.f_worker_name = worker_name.value
        worker.f_worker_id = worker_id
        for k, v in kwargs.items():
            attr = f"f_{k}"
            if hasattr(worker, attr) and v is not None:
                setattr(worker, attr, v)
        rows = worker.save(force_insert=True)
        if rows != 1:
            raise Exception("save worker info failed")

    @classmethod
    @DB.connection_context()
    def kill_task_all_workers(cls, task: Task):
        """Kill every recorded worker of a task; failures are logged (with
        traceback) but do not stop the remaining kills."""
        schedule_logger(task.f_job_id).info(start_log("kill all workers", task=task))
        workers_info = WorkerInfo.query(task_id=task.f_task_id, task_version=task.f_task_version, role=task.f_role,
                                        party_id=task.f_party_id)
        for worker_info in workers_info:
            schedule_logger(task.f_job_id).info(
                start_log(f"kill {worker_info.f_worker_name}({worker_info.f_run_pid})", task=task))
            try:
                cls.kill_worker(worker_info)
                schedule_logger(task.f_job_id).info(
                    successful_log(f"kill {worker_info.f_worker_name}({worker_info.f_run_pid})", task=task))
            except Exception:
                schedule_logger(task.f_job_id).warning(
                    failed_log(f"kill {worker_info.f_worker_name}({worker_info.f_run_pid})", task=task), exc_info=True)
        schedule_logger(task.f_job_id).info(successful_log("kill all workers", task=task))

    @classmethod
    def kill_worker(cls, worker_info: WorkerInfo):
        """Kill one worker process, verifying the pid still runs the recorded
        command line (guards against pid reuse)."""
        process_utils.kill_process(pid=worker_info.f_run_pid, expected_cmdline=worker_info.f_cmd)
|