_base.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import abc
  17. import typing
  18. from fate_flow.utils.log_utils import getLogger
  19. from fate_flow.components.param_extract import ParamExtract
  20. from fate_flow.scheduling_apps.client.tracker_client import TrackerClient
  21. LOGGER = getLogger()
  22. class ComponentInputProtocol(metaclass=abc.ABCMeta):
  23. @property
  24. @abc.abstractmethod
  25. def parameters(self) -> dict:
  26. ...
  27. @property
  28. @abc.abstractmethod
  29. def flow_feeded_parameters(self) -> dict:
  30. ...
  31. @property
  32. @abc.abstractmethod
  33. def roles(self):
  34. ...
  35. @property
  36. @abc.abstractmethod
  37. def job_parameters(self):
  38. ...
  39. @property
  40. @abc.abstractmethod
  41. def tracker(self):
  42. ...
  43. @property
  44. @abc.abstractmethod
  45. def task_version_id(self):
  46. ...
  47. @property
  48. @abc.abstractmethod
  49. def checkpoint_manager(self):
  50. ...
  51. @property
  52. @abc.abstractmethod
  53. def datasets(self):
  54. ...
  55. @property
  56. @abc.abstractmethod
  57. def models(self):
  58. ...
  59. class ComponentOutput:
  60. def __init__(self, data, models, cache: typing.List[tuple], serialize: bool = True) -> None:
  61. self._data = data
  62. if not isinstance(self._data, list):
  63. self._data = [data]
  64. self._models = models
  65. if self._models is None:
  66. self._models = {}
  67. self._cache = cache
  68. if not isinstance(self._cache, list):
  69. self._cache = [cache]
  70. self.serialize = serialize
  71. @property
  72. def data(self):
  73. return self._data
  74. @property
  75. def model(self):
  76. if not self.serialize:
  77. return self._models
  78. serialized_models: typing.Dict[str, typing.Tuple[str, bytes]] = {}
  79. for model_name, buffer_object in self._models.items():
  80. serialized_string = buffer_object.SerializeToString()
  81. if not serialized_string:
  82. from fate_arch.protobuf.python import default_empty_fill_pb2
  83. buffer_object = default_empty_fill_pb2.DefaultEmptyFillMessage()
  84. buffer_object.flag = "set"
  85. serialized_string = buffer_object.SerializeToString()
  86. pb_name = type(buffer_object).__name__
  87. serialized_models[model_name] = (pb_name, serialized_string)
  88. return serialized_models
  89. @property
  90. def cache(self):
  91. return self._cache
  92. class ComponentBase(metaclass=abc.ABCMeta):
  93. def __init__(self):
  94. self.task_version_id = ""
  95. self.tracker: TrackerClient = None
  96. self.checkpoint_manager = None
  97. self.model_output = None
  98. self.data_output = None
  99. self.cache_output = None
  100. self.serialize = True
  101. @abc.abstractmethod
  102. def _run(self, cpn_input: ComponentInputProtocol):
  103. """to be implemented"""
  104. ...
  105. def _retry(self, cpn_input: ComponentInputProtocol):
  106. ...
  107. # raise NotImplementedError(f"_retry for {type(self)} not implemented")
  108. def run(self, cpn_input: ComponentInputProtocol, retry: bool = True):
  109. self.task_version_id = cpn_input.task_version_id
  110. self.tracker = cpn_input.tracker
  111. self.checkpoint_manager = cpn_input.checkpoint_manager
  112. # retry
  113. if (
  114. retry
  115. and hasattr(self, '_retry')
  116. and callable(self._retry)
  117. and self.checkpoint_manager is not None
  118. and self.checkpoint_manager.latest_checkpoint is not None
  119. ):
  120. self._retry(cpn_input=cpn_input)
  121. # normal
  122. else:
  123. self._run(cpn_input=cpn_input)
  124. return ComponentOutput(data=self.save_data(), models=self.export_model(), cache=self.save_cache(), serialize=self.serialize)
  125. def save_data(self):
  126. return self.data_output
  127. def export_model(self):
  128. return self.model_output
  129. def save_cache(self):
  130. return self.cache_output
  131. class _RunnerDecorator:
  132. def __init__(self, meta) -> None:
  133. self._roles = set()
  134. self._meta = meta
  135. @property
  136. def on_guest(self):
  137. self._roles.add("guest")
  138. return self
  139. @property
  140. def on_host(self):
  141. self._roles.add("host")
  142. return self
  143. @property
  144. def on_arbiter(self):
  145. self._roles.add("arbiter")
  146. return self
  147. @property
  148. def on_local(self):
  149. self._roles.add("local")
  150. return self
  151. def __call__(self, cls):
  152. if issubclass(cls, ComponentBase):
  153. for role in self._roles:
  154. self._meta._role_to_runner_cls[role] = cls
  155. else:
  156. raise NotImplementedError(f"type of {cls} not supported")
  157. return cls
  158. class ComponentMeta:
  159. __name_to_obj: typing.Dict[str, "ComponentMeta"] = {}
  160. def __init__(self, name) -> None:
  161. self.name = name
  162. self._role_to_runner_cls = {}
  163. self._param_cls = None
  164. self.__name_to_obj[name] = self
  165. @property
  166. def bind_runner(self):
  167. return _RunnerDecorator(self)
  168. @property
  169. def bind_param(self):
  170. def _wrap(cls):
  171. self._param_cls = cls
  172. return cls
  173. return _wrap
  174. def register_info(self):
  175. return {
  176. self.name: dict(
  177. module=self.__module__,
  178. )
  179. }
  180. @classmethod
  181. def get_meta(cls, name):
  182. return cls.__name_to_obj[name]
  183. def _get_runner(self, role: str):
  184. if role not in self._role_to_runner_cls:
  185. raise ModuleNotFoundError(
  186. f"Runner for component `{self.name}` at role `{role}` not found"
  187. )
  188. return self._role_to_runner_cls[role]
  189. def get_run_obj(self, role: str):
  190. return self._get_runner(role)()
  191. def get_run_obj_name(self, role: str) -> str:
  192. return self._get_runner(role).__name__
  193. def get_param_obj(self, cpn_name: str):
  194. if self._param_cls is None:
  195. raise ModuleNotFoundError(f"Param for component `{self.name}` not found")
  196. param_obj = self._param_cls().set_name(f"{self.name}#{cpn_name}")
  197. return param_obj
  198. def get_supported_roles(self):
  199. roles = set(self._role_to_runner_cls.keys())
  200. if not roles:
  201. raise ModuleNotFoundError(f"roles for {self.name} is empty")
  202. return roles
  203. class BaseParam(object):
  204. def set_name(self, name: str):
  205. self._name = name
  206. return self
  207. def check(self):
  208. raise NotImplementedError("Parameter Object should have be check")
  209. def as_dict(self):
  210. return ParamExtract().change_param_to_dict(self)
  211. @classmethod
  212. def from_dict(cls, conf):
  213. obj = cls()
  214. obj.update(conf)
  215. return obj
  216. def update(self, conf, allow_redundant=False):
  217. return ParamExtract().recursive_parse_param_from_config(
  218. param=self,
  219. config_json=conf,
  220. param_parse_depth=0,
  221. valid_check=not allow_redundant,
  222. name=self._name,
  223. )