provider_manager.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import os.path
  17. import sys
  18. from copy import deepcopy
  19. from fate_arch.common import file_utils
  20. from fate_arch.common.versions import get_versions
  21. from fate_flow.controller.version_controller import VersionController
  22. from fate_flow.entity import ComponentProvider
  23. from fate_flow.db.component_registry import ComponentRegistry
  24. from fate_flow.db.job_default_config import JobDefaultConfig
  25. from fate_flow.manager.worker_manager import WorkerManager
  26. from fate_flow.entity.types import WorkerName
  27. from fate_flow.settings import stat_logger
  28. from fate_flow.utils.base_utils import get_fate_flow_python_directory
  class ProviderManager:
      """Class-level helpers for registering component providers and resolving them for jobs.

      All methods are classmethods; the class keeps no state of its own and
      reads provider data through ``ComponentRegistry``.
      """

      @classmethod
      def register_default_providers(cls):
          """Register the fate_flow provider, then the default fate provider.

          Returns:
              The default fate ``ComponentProvider`` produced by
              ``register_default_fate_provider``.

          Raises:
              Exception: if either registration reports a non-zero code.
          """
          code, _ = cls.register_fate_flow_provider()
          if code != 0:
              raise Exception("register fate flow tools component failed")
          code, _, provider = cls.register_default_fate_provider()
          if code != 0:
              raise Exception("register default fate algorithm component failed")
          return provider

      @classmethod
      def register_fate_flow_provider(cls):
          """Run the provider registrar worker in-process for the fate_flow provider.

          Returns:
              The value returned by ``WorkerManager.start_general_worker``;
              ``register_default_providers`` unpacks it as ``(code, result)``.
          """
          provider = cls.get_fate_flow_provider()
          return WorkerManager.start_general_worker(
              worker_name=WorkerName.PROVIDER_REGISTRAR,
              provider=provider,
              run_in_subprocess=False,
          )

      @classmethod
      def register_default_fate_provider(cls):
          """Run the provider registrar worker in-process for the default fate provider.

          The provider's ``env["PYTHONPATH"]`` entry is added to ``sys.path``
          before the worker runs.

          Returns:
              A ``(code, result, provider)`` tuple, where ``code`` and ``result``
              come from ``WorkerManager.start_general_worker``.
          """
          provider = cls.get_default_fate_provider()
          python_path = provider.env["PYTHONPATH"]
          # Guard against duplicates: this method may run more than once per process,
          # and an entry already on sys.path keeps its existing precedence.
          if python_path not in sys.path:
              sys.path.append(python_path)
          code, result = WorkerManager.start_general_worker(
              worker_name=WorkerName.PROVIDER_REGISTRAR,
              provider=provider,
              run_in_subprocess=False,
          )
          return code, result, provider

      @classmethod
      def get_fate_flow_provider(cls) -> ComponentProvider:
          """Build the ``ComponentProvider`` named "fate_flow".

          The version is ``get_versions()["FATEFlow"]`` and the path comes from
          ``get_fate_flow_python_directory("fate_flow")``.
          """
          path = get_fate_flow_python_directory("fate_flow")
          return ComponentProvider(
              name="fate_flow",
              version=get_versions()["FATEFlow"],
              path=path,
              class_path=ComponentRegistry.get_default_class_path(),
          )

      @classmethod
      def get_default_fate_provider_env(cls):
          """Return the ``env`` attribute of the default fate provider."""
          return cls.get_default_fate_provider().env

      @classmethod
      def get_default_fate_provider(cls) -> ComponentProvider:
          """Build the ``ComponentProvider`` named "fate" from the job default config.

          ``JobDefaultConfig.default_component_provider_path`` is split on "/"
          and handed to ``file_utils.get_fate_python_directory`` to obtain the
          provider path.

          Raises:
              Exception: if the resolved path does not exist on disk.
          """
          path_parts = JobDefaultConfig.default_component_provider_path.split("/")
          path = file_utils.get_fate_python_directory(*path_parts)
          if not os.path.exists(path):
              raise Exception(f"default fate provider not exists: {path}")
          return ComponentProvider(
              name="fate",
              version=get_versions()["FATE"],
              path=path,
              class_path=ComponentRegistry.get_default_class_path(),
          )

      @classmethod
      def if_default_provider(cls, provider: ComponentProvider) -> bool:
          """Return True if ``provider`` equals the fate_flow or the default fate provider.

          The comparison short-circuits: the default fate provider (whose
          construction raises when its path is missing) is only built when the
          fate_flow comparison is false.
          """
          return bool(
              provider == cls.get_fate_flow_provider()
              or provider == cls.get_default_fate_provider()
          )

      @classmethod
      def fill_fate_flow_provider(cls, dsl):
          """Return a deep copy of ``dsl`` with the fate_flow provider filled in.

          Every entry of ``dsl["components"]`` whose ``"module"`` is among the
          components registered for the fate_flow provider gets its
          ``"provider"`` key set to ``"<name>@<version>"``. The input ``dsl``
          itself is not modified.
          """
          dest_dsl = deepcopy(dsl)
          fate_flow_provider = cls.get_fate_flow_provider()
          support_components = ComponentRegistry.get_provider_components(
              fate_flow_provider.name, fate_flow_provider.version
          )
          provider_key = f"{fate_flow_provider.name}@{fate_flow_provider.version}"
          for cpn, config in dsl["components"].items():
              if config["module"] in support_components:
                  dest_dsl["components"][cpn]["provider"] = provider_key
          return dest_dsl

      @classmethod
      def get_fate_flow_component_module(cls):
          """Return the components registered for the fate_flow provider's name and version."""
          fate_flow_provider = cls.get_fate_flow_provider()
          return ComponentRegistry.get_provider_components(
              fate_flow_provider.name, fate_flow_provider.version
          )

      @classmethod
      def get_provider_object(cls, provider_info, check_registration=True) -> ComponentProvider:
          """Build a ``ComponentProvider`` from a ``{"name": ..., "version": ...}`` mapping.

          Args:
              provider_info: mapping that must contain "name" and "version".
              check_registration: when true, an unregistered name/version pair
                  is an error; when false, defaults are used instead.

          Returns:
              A ``ComponentProvider`` whose ``path`` and ``class_path`` come from
              the registry entry, falling back to ``[]`` and to
              ``ComponentRegistry.REGISTRY["default_settings"]["class_path"]``.

          Raises:
              Exception: if ``check_registration`` is true and the provider is
                  not found in ``ComponentRegistry.get_providers()``.
          """
          name, version = provider_info["name"], provider_info["version"]
          # Look the entry up once instead of walking the registry three times.
          registered = ComponentRegistry.get_providers().get(name, {}).get(version, None)
          if registered is None:
              if check_registration:
                  raise Exception(f"{name} {version} provider is not registered")
              # Unregistered (or explicitly None) entry: fall back to defaults below.
              registered = {}
          path = registered.get("path", [])
          class_path = registered.get("class_path", None)
          if class_path is None:
              class_path = ComponentRegistry.REGISTRY["default_settings"]["class_path"]
          return ComponentProvider(name=name, version=version, path=path, class_path=class_path)

      @classmethod
      def get_job_provider_group(cls, dsl_parser, role, party_id, components: list = None, check_registration=True,
                                 runtime_conf=None, check_version=False, is_scheduler=False):
          """Group a job's components by the provider ("name@version") that serves them.

          Args:
              dsl_parser: object exposing ``get_job_providers``.
              role: local role, forwarded to the parser and used to select the
                  per-role section of the parsed result.
              party_id: local party id; both its ``int`` and ``str`` forms are
                  tried as keys of the per-role section.
              components: optional list of component names to restrict the
                  result to.
              check_registration: forwarded to ``get_provider_object``.
              runtime_conf: forwarded to the parser as ``conf`` when not scheduler.
              check_version: when true, run
                  ``VersionController.job_provider_version_check`` on the parsed result.
              is_scheduler: when true, the parser is called without conf/role/party
                  arguments and the per-role selection is skipped.

          Returns:
              A dict keyed by ``"<name>@<version>"``; each value holds
              ``"provider"`` (``provider.to_dict()``), ``"if_default_provider"``
              and ``"components"`` (list of component names).
          """
          if is_scheduler:
              # local provider
              providers_info = dsl_parser.get_job_providers(provider_detail=ComponentRegistry.REGISTRY)
          else:
              providers_info = dsl_parser.get_job_providers(provider_detail=ComponentRegistry.REGISTRY,
                                                            conf=runtime_conf,
                                                            local_role=role,
                                                            local_party_id=party_id)
          if check_version:
              VersionController.job_provider_version_check(providers_info, local_role=role, local_party_id=party_id)
          group = {}
          if role in providers_info and not is_scheduler:
              # The party id key may be stored as int or as str; try both.
              providers_info = providers_info.get(role, {}).get(int(party_id), {}) or \
                  providers_info.get(role, {}).get(str(party_id), {})
          if components is not None:
              # NOTE(review): a name absent from providers_info maps to None here and
              # fails on the ["provider"] lookup below - confirm whether callers can
              # ever pass such a name.
              providers_info = {
                  component_name: providers_info.get(component_name)
                  for component_name in components
              }
          for component_name, provider_info in providers_info.items():
              provider = cls.get_provider_object(provider_info["provider"], check_registration=check_registration)
              group_key = "@".join([provider.name, provider.version])
              if group_key not in group:
                  group[group_key] = {
                      "provider": provider.to_dict(),
                      "if_default_provider": cls.if_default_provider(provider),
                      "components": [component_name]
                  }
              else:
                  group[group_key]["components"].append(component_name)
          return group

      @classmethod
      def get_component_provider(cls, dsl_parser, component_name) -> ComponentProvider:
          """Return the registered provider object for ``component_name``.

          The provider info is taken from
          ``dsl_parser.get_job_providers(...)[component_name]["provider"]`` and
          resolved with registration checking enabled.
          """
          providers = dsl_parser.get_job_providers(provider_detail=ComponentRegistry.REGISTRY)
          return cls.get_provider_object(providers[component_name]["provider"])

      @classmethod
      def get_component_parameters(cls, dsl_parser, component_name, role, party_id,
                                   provider: ComponentProvider = None, previous_components_parameters: dict = None):
          """Parse both the full and the user-specified parameters of a component.

          Args:
              dsl_parser: object exposing ``parse_component_parameters`` and
                  ``parse_user_specified_component_parameters``.
              component_name: name of the component to parse.
              role: local role, forwarded as ``local_role``.
              party_id: local party id, converted with ``int`` before forwarding.
              provider: provider to use; resolved via ``get_component_provider``
                  when falsy.
              previous_components_parameters: forwarded as ``previous_parameters``.

          Returns:
              A ``(parameters, user_specified_parameters)`` tuple.
          """
          if not provider:
              provider = cls.get_component_provider(dsl_parser=dsl_parser,
                                                    component_name=component_name)
          parameters = dsl_parser.parse_component_parameters(
              component_name,
              ComponentRegistry.REGISTRY,
              provider.name,
              provider.version,
              local_role=role,
              local_party_id=int(party_id),
              previous_parameters=previous_components_parameters,
          )
          user_specified_parameters = dsl_parser.parse_user_specified_component_parameters(
              component_name,
              ComponentRegistry.REGISTRY,
              provider.name,
              provider.version,
              local_role=role,
              local_party_id=int(party_id),
              previous_parameters=previous_components_parameters,
          )
          return parameters, user_specified_parameters

      @classmethod
      def get_component_run_info(cls, dsl_parser, component_name, role, party_id,
                                 previous_components_parameters: dict = None):
          """Resolve a component's provider and parse its parameters in one call.

          Returns:
              A ``(provider, parameters, user_specified_parameters)`` tuple.
          """
          provider = cls.get_component_provider(dsl_parser, component_name)
          parameters, user_specified_parameters = cls.get_component_parameters(
              dsl_parser, component_name, role, party_id, provider, previous_components_parameters
          )
          return provider, parameters, user_specified_parameters