#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
import subprocess
import sys
from uuid import uuid1

from fate_arch.common.base_utils import current_timestamp, json_dumps
from fate_arch.common.file_utils import load_json_conf
from fate_arch.metastore.base_model import auto_date_timestamp_db_field

from fate_flow.db.db_models import DB, Task, WorkerInfo
from fate_flow.db.runtime_config import RuntimeConfig
from fate_flow.entity import ComponentProvider, RunParameters
from fate_flow.entity.types import WorkerName
from fate_flow.settings import stat_logger
from fate_flow.utils import job_utils, process_utils
from fate_flow.utils.log_utils import failed_log, ready_log, schedule_logger, start_log, successful_log

  30. class WorkerManager:
  31. @classmethod
  32. def start_general_worker(cls, worker_name: WorkerName, job_id="", role="", party_id=0, provider: ComponentProvider = None,
  33. initialized_config: dict = None, run_in_subprocess=True, **kwargs):
  34. if RuntimeConfig.DEBUG:
  35. run_in_subprocess = True
  36. participate = locals()
  37. worker_id, config_dir, log_dir = cls.get_process_dirs(worker_name=worker_name,
  38. job_id=job_id,
  39. role=role,
  40. party_id=party_id)
  41. if worker_name in [WorkerName.PROVIDER_REGISTRAR, WorkerName.DEPENDENCE_UPLOAD]:
  42. if not provider:
  43. raise ValueError("no provider argument")
  44. config = {
  45. "provider": provider.to_dict()
  46. }
  47. if worker_name == WorkerName.PROVIDER_REGISTRAR:
  48. from fate_flow.worker.provider_registrar import ProviderRegistrar
  49. module = ProviderRegistrar
  50. module_file_path = sys.modules[ProviderRegistrar.__module__].__file__
  51. specific_cmd = []
  52. elif worker_name == WorkerName.DEPENDENCE_UPLOAD:
  53. from fate_flow.worker.dependence_upload import DependenceUpload
  54. module = DependenceUpload
  55. module_file_path = sys.modules[DependenceUpload.__module__].__file__
  56. specific_cmd = [
  57. '--dependence_type', kwargs.get("dependence_type")
  58. ]
  59. provider_info = provider.to_dict()
  60. elif worker_name is WorkerName.TASK_INITIALIZER:
  61. if not initialized_config:
  62. raise ValueError("no initialized_config argument")
  63. config = initialized_config
  64. from fate_flow.worker.task_initializer import TaskInitializer
  65. module = TaskInitializer
  66. module_file_path = sys.modules[TaskInitializer.__module__].__file__
  67. specific_cmd = []
  68. provider_info = initialized_config["provider"]
  69. else:
  70. raise Exception(f"not support {worker_name} worker")
  71. config_path, result_path = cls.get_config(config_dir=config_dir, config=config, log_dir=log_dir)
  72. process_cmd = [
  73. sys.executable or "python3",
  74. module_file_path,
  75. "--config", config_path,
  76. '--result', result_path,
  77. "--log_dir", log_dir,
  78. "--parent_log_dir", os.path.dirname(log_dir),
  79. "--worker_id", worker_id,
  80. "--run_ip", RuntimeConfig.JOB_SERVER_HOST,
  81. "--job_server", f"{RuntimeConfig.JOB_SERVER_HOST}:{RuntimeConfig.HTTP_PORT}",
  82. ]
  83. if job_id:
  84. process_cmd.extend([
  85. "--job_id", job_id,
  86. "--role", role,
  87. "--party_id", party_id,
  88. ])
  89. process_cmd.extend(specific_cmd)
  90. if run_in_subprocess:
  91. p = process_utils.run_subprocess(job_id=job_id, config_dir=config_dir, process_cmd=process_cmd,
  92. added_env=cls.get_env(job_id, provider_info), log_dir=log_dir,
  93. cwd_dir=config_dir, process_name=worker_name.value, process_id=worker_id)
  94. participate["pid"] = p.pid
  95. if job_id and role and party_id:
  96. logger = schedule_logger(job_id)
  97. msg = f"{worker_name} worker {worker_id} subprocess {p.pid}"
  98. else:
  99. logger = stat_logger
  100. msg = f"{worker_name} worker {worker_id} subprocess {p.pid}"
  101. logger.info(ready_log(msg=msg, role=role, party_id=party_id))
  102. # asynchronous
  103. if worker_name in [WorkerName.DEPENDENCE_UPLOAD]:
  104. if kwargs.get("callback") and kwargs.get("callback_param"):
  105. callback_param = {}
  106. participate.update(participate.get("kwargs", {}))
  107. for k, v in participate.items():
  108. if k in kwargs.get("callback_param"):
  109. callback_param[k] = v
  110. kwargs.get("callback")(**callback_param)
  111. else:
  112. try:
  113. p.wait(timeout=120)
  114. if p.returncode == 0:
  115. logger.info(successful_log(msg=msg, role=role, party_id=party_id))
  116. else:
  117. logger.info(failed_log(msg=msg, role=role, party_id=party_id))
  118. if p.returncode == 0:
  119. return p.returncode, load_json_conf(result_path)
  120. else:
  121. std_path = process_utils.get_std_path(log_dir=log_dir, process_name=worker_name.value, process_id=worker_id)
  122. raise Exception(f"run error, please check logs: {std_path}, {log_dir}/INFO.log")
  123. except subprocess.TimeoutExpired as e:
  124. err = failed_log(msg=f"{msg} run timeout", role=role, party_id=party_id)
  125. logger.exception(err)
  126. raise Exception(err)
  127. finally:
  128. try:
  129. p.kill()
  130. p.poll()
  131. except Exception as e:
  132. logger.exception(e)
  133. else:
  134. kwargs = cls.cmd_to_func_kwargs(process_cmd)
  135. code, message, result = module().run(**kwargs)
  136. if code == 0:
  137. return code, result
  138. else:
  139. raise Exception(message)
  140. @classmethod
  141. def start_task_worker(cls, worker_name, task: Task, task_parameters: RunParameters = None,
  142. executable: list = None, extra_env: dict = None, **kwargs):
  143. worker_id, config_dir, log_dir = cls.get_process_dirs(worker_name=worker_name,
  144. job_id=task.f_job_id,
  145. role=task.f_role,
  146. party_id=task.f_party_id,
  147. task=task)
  148. session_id = job_utils.generate_session_id(task.f_task_id, task.f_task_version, task.f_role, task.f_party_id)
  149. federation_session_id = job_utils.generate_task_version_id(task.f_task_id, task.f_task_version)
  150. info_kwargs = {}
  151. specific_cmd = []
  152. if worker_name is WorkerName.TASK_EXECUTOR:
  153. from fate_flow.worker.task_executor import TaskExecutor
  154. module_file_path = sys.modules[TaskExecutor.__module__].__file__
  155. else:
  156. raise Exception(f"not support {worker_name} worker")
  157. if task_parameters is None:
  158. task_parameters = RunParameters(**job_utils.get_job_parameters(task.f_job_id, task.f_role, task.f_party_id))
  159. config = task_parameters.to_dict()
  160. config["src_user"] = kwargs.get("src_user")
  161. config_path, result_path = cls.get_config(config_dir=config_dir, config=config, log_dir=log_dir)
  162. env = cls.get_env(task.f_job_id, task.f_provider_info)
  163. if executable:
  164. process_cmd = executable
  165. else:
  166. process_cmd = [env.get("PYTHON_ENV") or sys.executable or "python3"]
  167. common_cmd = [
  168. module_file_path,
  169. "--job_id", task.f_job_id,
  170. "--component_name", task.f_component_name,
  171. "--task_id", task.f_task_id,
  172. "--task_version", task.f_task_version,
  173. "--role", task.f_role,
  174. "--party_id", task.f_party_id,
  175. "--config", config_path,
  176. '--result', result_path,
  177. "--log_dir", log_dir,
  178. "--parent_log_dir", os.path.dirname(log_dir),
  179. "--worker_id", worker_id,
  180. "--run_ip", RuntimeConfig.JOB_SERVER_HOST,
  181. "--run_port", RuntimeConfig.HTTP_PORT,
  182. "--job_server", f"{RuntimeConfig.JOB_SERVER_HOST}:{RuntimeConfig.HTTP_PORT}",
  183. "--session_id", session_id,
  184. "--federation_session_id", federation_session_id
  185. ]
  186. process_cmd.extend(common_cmd)
  187. process_cmd.extend(specific_cmd)
  188. if extra_env:
  189. env.update(extra_env)
  190. schedule_logger(task.f_job_id).info(
  191. f"task {task.f_task_id} {task.f_task_version} on {task.f_role} {task.f_party_id} {worker_name} worker subprocess is ready")
  192. p = process_utils.run_subprocess(job_id=task.f_job_id, config_dir=config_dir, process_cmd=process_cmd,
  193. added_env=env, log_dir=log_dir, cwd_dir=config_dir, process_name=worker_name.value,
  194. process_id=worker_id)
  195. cls.save_worker_info(task=task, worker_name=worker_name, worker_id=worker_id, run_ip=RuntimeConfig.JOB_SERVER_HOST, run_pid=p.pid, config=config, cmd=process_cmd, **info_kwargs)
  196. return {"run_pid": p.pid, "worker_id": worker_id, "cmd": process_cmd}
  197. @classmethod
  198. def get_process_dirs(cls, worker_name: WorkerName, job_id=None, role=None, party_id=None, task: Task = None):
  199. worker_id = uuid1().hex
  200. party_id = str(party_id)
  201. if task:
  202. config_dir = job_utils.get_job_directory(job_id, role, party_id, task.f_component_name, task.f_task_id,
  203. str(task.f_task_version), worker_name.value, worker_id)
  204. log_dir = job_utils.get_job_log_directory(job_id, role, party_id, task.f_component_name)
  205. elif job_id and role and party_id:
  206. config_dir = job_utils.get_job_directory(job_id, role, party_id, worker_name.value, worker_id)
  207. log_dir = job_utils.get_job_log_directory(job_id, role, party_id, worker_name.value, worker_id)
  208. else:
  209. config_dir = job_utils.get_general_worker_directory(worker_name.value, worker_id)
  210. log_dir = job_utils.get_general_worker_log_directory(worker_name.value, worker_id)
  211. os.makedirs(config_dir, exist_ok=True)
  212. return worker_id, config_dir, log_dir
  213. @classmethod
  214. def get_config(cls, config_dir, config, log_dir):
  215. config_path = os.path.join(config_dir, "config.json")
  216. with open(config_path, 'w') as fw:
  217. fw.write(json_dumps(config))
  218. result_path = os.path.join(config_dir, "result.json")
  219. return config_path, result_path
  220. @classmethod
  221. def get_env(cls, job_id, provider_info):
  222. provider = ComponentProvider(**provider_info)
  223. env = provider.env.copy()
  224. env["PYTHONPATH"] = os.path.dirname(provider.path)
  225. if job_id:
  226. env["FATE_JOB_ID"] = job_id
  227. return env
  228. @classmethod
  229. def cmd_to_func_kwargs(cls, cmd):
  230. kwargs = {}
  231. for i in range(2, len(cmd), 2):
  232. kwargs[cmd[i].lstrip("--")] = cmd[i+1]
  233. return kwargs
  234. @classmethod
  235. @DB.connection_context()
  236. def save_worker_info(cls, task: Task, worker_name: WorkerName, worker_id, **kwargs):
  237. worker = WorkerInfo()
  238. ignore_attr = auto_date_timestamp_db_field()
  239. for attr, value in task.to_dict().items():
  240. if hasattr(worker, attr) and attr not in ignore_attr and value is not None:
  241. setattr(worker, attr, value)
  242. worker.f_create_time = current_timestamp()
  243. worker.f_worker_name = worker_name.value
  244. worker.f_worker_id = worker_id
  245. for k, v in kwargs.items():
  246. attr = f"f_{k}"
  247. if hasattr(worker, attr) and v is not None:
  248. setattr(worker, attr, v)
  249. rows = worker.save(force_insert=True)
  250. if rows != 1:
  251. raise Exception("save worker info failed")
  252. @classmethod
  253. @DB.connection_context()
  254. def kill_task_all_workers(cls, task: Task):
  255. schedule_logger(task.f_job_id).info(start_log("kill all workers", task=task))
  256. workers_info = WorkerInfo.query(task_id=task.f_task_id, task_version=task.f_task_version, role=task.f_role,
  257. party_id=task.f_party_id)
  258. for worker_info in workers_info:
  259. schedule_logger(task.f_job_id).info(
  260. start_log(f"kill {worker_info.f_worker_name}({worker_info.f_run_pid})", task=task))
  261. try:
  262. cls.kill_worker(worker_info)
  263. schedule_logger(task.f_job_id).info(
  264. successful_log(f"kill {worker_info.f_worker_name}({worker_info.f_run_pid})", task=task))
  265. except Exception as e:
  266. schedule_logger(task.f_job_id).warning(
  267. failed_log(f"kill {worker_info.f_worker_name}({worker_info.f_run_pid})", task=task), exc_info=True)
  268. schedule_logger(task.f_job_id).info(successful_log("kill all workers", task=task))
  269. @classmethod
  270. def kill_worker(cls, worker_info: WorkerInfo):
  271. process_utils.kill_process(pid=worker_info.f_run_pid, expected_cmdline=worker_info.f_cmd)