components.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import importlib
  17. import inspect
  18. import typing
  19. from pathlib import Path
  20. from federatedml.model_base import ModelBase
  21. from federatedml.param.base_param import BaseParam
  22. from federatedml.util import LOGGER
# Resolved path of the directory three levels above this file.
# NOTE(review): not referenced anywhere in the visible code; presumably the
# root directory that contains the ``federatedml`` package — confirm before
# relying on it.
_ml_base = Path(__file__).resolve().parent.parent.parent
  24. class _RunnerDecorator:
  25. def __init__(self, meta: "ComponentMeta") -> None:
  26. self._roles = set()
  27. self._meta = meta
  28. @property
  29. def on_guest(self):
  30. self._roles.add("guest")
  31. return self
  32. @property
  33. def on_host(self):
  34. self._roles.add("host")
  35. return self
  36. @property
  37. def on_arbiter(self):
  38. self._roles.add("arbiter")
  39. return self
  40. @property
  41. def on_local(self):
  42. self._roles.add("local")
  43. return self
  44. def __call__(self, cls):
  45. if inspect.isclass(cls) and issubclass(cls, ModelBase):
  46. for role in self._roles:
  47. self._meta._role_to_runner_cls[role] = cls
  48. elif inspect.isfunction(cls):
  49. for role in self._roles:
  50. self._meta._role_to_runner_cls_getter[role] = cls
  51. else:
  52. raise NotImplementedError(f"type of {cls} not supported")
  53. return cls
  54. class ComponentMeta:
  55. __name_to_obj: typing.Dict[str, "ComponentMeta"] = {}
  56. def __init__(self, name, *others) -> None:
  57. if len(others) > 0:
  58. self._alias = [name, *others]
  59. self._name = "|".join(self._alias)
  60. else:
  61. self._alias = [name]
  62. self._name = name
  63. self._role_to_runner_cls = {}
  64. self._role_to_runner_cls_getter = {} # lazy
  65. self._param_cls = None
  66. self._param_cls_getter = None # lazy
  67. for alias in self._alias:
  68. self.__name_to_obj[alias] = self
  69. @property
  70. def name(self):
  71. return self._name
  72. @property
  73. def alias(self):
  74. return self._alias
  75. @classmethod
  76. def get_meta(cls, name):
  77. return cls.__name_to_obj[name]
  78. @property
  79. def bind_runner(self):
  80. return _RunnerDecorator(self)
  81. @property
  82. def bind_param(self):
  83. def _wrap(cls):
  84. if inspect.isclass(cls) and issubclass(cls, BaseParam):
  85. self._param_cls = cls
  86. elif inspect.isfunction(cls):
  87. self._param_cls_getter = cls
  88. else:
  89. raise NotImplementedError(f"type of {cls} not supported")
  90. return cls
  91. return _wrap
  92. def _get_runner(self, role: str):
  93. if role in self._role_to_runner_cls:
  94. runner_class = self._role_to_runner_cls[role]
  95. elif role in self._role_to_runner_cls_getter:
  96. runner_class = self._role_to_runner_cls_getter[role]()
  97. else:
  98. raise ModuleNotFoundError(
  99. f"Runner for component `{self.name}` at role `{role}` not found"
  100. )
  101. runner_class.set_component_name(self.alias[0])
  102. return runner_class
  103. def get_run_obj(self, role: str):
  104. return self._get_runner(role)()
  105. def get_run_obj_name(self, role: str) -> str:
  106. return self._get_runner(role).__name__
  107. def get_param_obj(self, cpn_name: str):
  108. if self._param_cls is not None:
  109. param_obj = self._param_cls()
  110. elif self._param_cls_getter is not None:
  111. param_obj = self._param_cls_getter()()
  112. else:
  113. raise ModuleNotFoundError(f"Param for component `{self.name}` not found")
  114. return param_obj.set_name(f"{self.name}#{cpn_name}")
  115. def get_supported_roles(self):
  116. return set(self._role_to_runner_cls) | set(self._role_to_runner_cls_getter)
  117. def _get_module_name_by_path(path, base):
  118. return '.'.join(path.resolve().relative_to(base.resolve()).with_suffix('').parts)
  119. def _search_components(path, base):
  120. try:
  121. module_name = _get_module_name_by_path(path, base)
  122. module = importlib.import_module(module_name)
  123. except ImportError as e:
  124. # or skip ?
  125. raise e
  126. _obj_pairs = inspect.getmembers(module, lambda obj: isinstance(obj, ComponentMeta))
  127. return _obj_pairs, module_name
  128. class Components:
  129. provider_version = None
  130. provider_name = None
  131. provider_path = None
  132. @classmethod
  133. def _module_base(cls):
  134. return Path(cls.provider_path).resolve().parent
  135. @classmethod
  136. def _components_base(cls):
  137. return Path(cls.provider_path, 'components').resolve()
  138. @classmethod
  139. def get_names(cls) -> typing.Dict[str, dict]:
  140. names = {}
  141. for p in cls._components_base().glob("**/*.py"):
  142. obj_pairs, module_name = _search_components(p, cls._module_base())
  143. for name, obj in obj_pairs:
  144. for alias in obj.alias:
  145. names[alias] = {"module": module_name}
  146. LOGGER.info(
  147. f"component register {obj.name} with cache info {module_name}"
  148. )
  149. return names
  150. @classmethod
  151. def get(cls, name: str, cache) -> ComponentMeta:
  152. if cache:
  153. importlib.import_module(cache[name]["module"])
  154. else:
  155. for p in cls._components_base().glob("**/*.py"):
  156. module_name = _get_module_name_by_path(p, cls._module_base())
  157. importlib.import_module(module_name)
  158. return ComponentMeta.get_meta(name)