base_worker.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import argparse
  17. import os
  18. import sys
  19. import traceback
  20. import logging
  21. from fate_arch.common.base_utils import current_timestamp
  22. from fate_arch.common.file_utils import load_json_conf, dump_json_conf
  23. from fate_flow.utils.log_utils import getLogger, LoggerFactory, exception_to_trace_string
  24. from fate_flow.db.component_registry import ComponentRegistry
  25. from fate_flow.db.config_manager import ConfigManager
  26. from fate_flow.db.runtime_config import RuntimeConfig
  27. from fate_flow.entity.types import ProcessRole
  28. from fate_flow.entity import BaseEntity
# Module-level logger shared by the worker classes defined below.
LOGGER = getLogger()
  30. class WorkerArgs(BaseEntity):
  31. def __init__(self, **kwargs):
  32. self.job_id = kwargs.get("job_id")
  33. self.component_name = kwargs.get("component_name")
  34. self.task_id = kwargs.get("task_id")
  35. self.task_version = kwargs.get("task_version")
  36. self.role = kwargs.get("role")
  37. self.party_id = kwargs.get("party_id")
  38. self.config = self.load_dict_attr(kwargs, "config")
  39. self.result = kwargs.get("result")
  40. self.log_dir = kwargs.get("log_dir")
  41. self.parent_log_dir = kwargs.get("parent_log_dir")
  42. self.worker_id = kwargs.get("worker_id")
  43. self.run_ip = kwargs.get("run_ip")
  44. self.run_port = kwargs.get("run_port")
  45. self.job_server = kwargs.get("job_server")
  46. # TaskInitializer
  47. self.result = kwargs.get("result")
  48. self.dsl = self.load_dict_attr(kwargs, "dsl")
  49. self.runtime_conf = self.load_dict_attr(kwargs, "runtime_conf")
  50. self.train_runtime_conf = self.load_dict_attr(kwargs, "train_runtime_conf")
  51. self.pipeline_dsl = self.load_dict_attr(kwargs, "pipeline_dsl")
  52. # TaskSender & TaskReceiver
  53. self.session_id = kwargs.get("session_id")
  54. self.federation_session_id = kwargs.get("federation_session_id")
  55. # TaskSender
  56. self.receive_topic = kwargs.get("receive_topic")
  57. # TaskReceiver
  58. self.http_port = kwargs.get("http_port")
  59. self.grpc_port = kwargs.get("grpc_port")
  60. # Dependence Upload
  61. self.dependence_type = kwargs.get("dependence_type")
  62. @staticmethod
  63. def load_dict_attr(kwargs: dict, attr_name: str):
  64. return load_json_conf(kwargs[attr_name]) if kwargs.get(attr_name) else {}
  65. class BaseWorker:
  66. def __init__(self):
  67. self.args: WorkerArgs = None
  68. self.run_pid = None
  69. self.report_info = {}
  70. def run(self, **kwargs):
  71. result = {}
  72. code = 0
  73. message = ""
  74. start_time = current_timestamp()
  75. self.run_pid = os.getpid()
  76. try:
  77. self.args = self.get_args(**kwargs)
  78. RuntimeConfig.init_env()
  79. RuntimeConfig.set_process_role(ProcessRole(os.getenv("PROCESS_ROLE")))
  80. if RuntimeConfig.PROCESS_ROLE == ProcessRole.WORKER:
  81. LoggerFactory.LEVEL = logging.getLevelName(os.getenv("FATE_LOG_LEVEL", "INFO"))
  82. LoggerFactory.set_directory(directory=self.args.log_dir, parent_log_dir=self.args.parent_log_dir,
  83. append_to_parent_log=True, force=True)
  84. LOGGER.info(f"enter {self.__class__.__name__} worker in subprocess, pid: {self.run_pid}")
  85. else:
  86. LOGGER.info(f"enter {self.__class__.__name__} worker in driver process, pid: {self.run_pid}")
  87. LOGGER.info(f"log level: {logging.getLevelName(LoggerFactory.LEVEL)}")
  88. for env in {"VIRTUAL_ENV", "PYTHONPATH", "SPARK_HOME", "FATE_DEPLOY_BASE", "PROCESS_ROLE", "FATE_JOB_ID"}:
  89. LOGGER.info(f"{env}: {os.getenv(env)}")
  90. if self.args.job_server:
  91. RuntimeConfig.init_config(JOB_SERVER_HOST=self.args.job_server.split(':')[0],
  92. HTTP_PORT=self.args.job_server.split(':')[1])
  93. if not RuntimeConfig.LOAD_COMPONENT_REGISTRY:
  94. ComponentRegistry.load()
  95. if not RuntimeConfig.LOAD_CONFIG_MANAGER:
  96. ConfigManager.load()
  97. result = self._run()
  98. except Exception as e:
  99. LOGGER.exception(e)
  100. traceback.print_exc()
  101. try:
  102. self._handle_exception()
  103. except Exception as e:
  104. LOGGER.exception(e)
  105. code = 1
  106. message = exception_to_trace_string(e)
  107. finally:
  108. if self.args and self.args.result:
  109. dump_json_conf(result, self.args.result)
  110. end_time = current_timestamp()
  111. LOGGER.info(f"worker {self.__class__.__name__}, process role: {RuntimeConfig.PROCESS_ROLE}, pid: {self.run_pid}, elapsed: {end_time - start_time} ms")
  112. if RuntimeConfig.PROCESS_ROLE == ProcessRole.WORKER:
  113. sys.exit(code)
  114. else:
  115. return code, message, result
  116. def _run(self):
  117. raise NotImplementedError
  118. def _handle_exception(self):
  119. pass
  120. @staticmethod
  121. def get_args(**kwargs):
  122. if kwargs:
  123. return WorkerArgs(**kwargs)
  124. else:
  125. parser = argparse.ArgumentParser()
  126. for arg in WorkerArgs().to_dict():
  127. parser.add_argument(f"--{arg}", required=False)
  128. return WorkerArgs(**parser.parse_args().__dict__)