|
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import abc
- import typing
- from fate_flow.utils.log_utils import getLogger
- from fate_flow.components.param_extract import ParamExtract
- from fate_flow.scheduling_apps.client.tracker_client import TrackerClient
- LOGGER = getLogger()
class ComponentInputProtocol(metaclass=abc.ABCMeta):
    """Abstract, read-only view of the input handed to a component.

    Every member is an abstract property, so a concrete input class has to
    implement all nine of them before it can be instantiated.
    """

    @property
    @abc.abstractmethod
    def parameters(self) -> dict:
        """Return the component parameters as a dict."""

    @property
    @abc.abstractmethod
    def flow_feeded_parameters(self) -> dict:
        """Return the flow-fed parameters as a dict."""

    @property
    @abc.abstractmethod
    def roles(self):
        """Return the roles attached to this input."""

    @property
    @abc.abstractmethod
    def job_parameters(self):
        """Return the job-level parameters."""

    @property
    @abc.abstractmethod
    def tracker(self):
        """Return the tracker; ``ComponentBase.run`` copies it onto the component."""

    @property
    @abc.abstractmethod
    def task_version_id(self):
        """Return the task version id; ``ComponentBase.run`` copies it onto the component."""

    @property
    @abc.abstractmethod
    def checkpoint_manager(self):
        """Return the checkpoint manager (may be ``None``) consulted by ``ComponentBase.run``."""

    @property
    @abc.abstractmethod
    def datasets(self):
        """Return the input datasets."""

    @property
    @abc.abstractmethod
    def models(self):
        """Return the input models."""
class ComponentOutput:
    """Container for the data, models and cache produced by a component run."""

    def __init__(self, data, models, cache: typing.List[tuple], serialize: bool = True) -> None:
        """Store the outputs, normalising their containers.

        Args:
            data: Output data; anything that is not a ``list`` (``None``
                included) is wrapped into a one-element list.
            models: Mapping of model name to buffer object; ``None`` becomes
                an empty dict.
            cache: Output cache; wrapped into a one-element list when it is
                not already a ``list``.
            serialize: Whether ``model`` should serialize the buffer objects.
        """
        self._data = data if isinstance(data, list) else [data]
        self._models = {} if models is None else models
        self._cache = cache if isinstance(cache, list) else [cache]
        self.serialize = serialize

    @property
    def data(self):
        """Return the output data as a list."""
        return self._data

    @property
    def model(self):
        """Return the output models.

        With ``serialize`` disabled the stored mapping is returned untouched.
        Otherwise each buffer object is serialized and the result maps the
        model name to ``(buffer type name, serialized bytes)``.
        """
        if not self.serialize:
            return self._models

        packed: typing.Dict[str, typing.Tuple[str, bytes]] = {}
        for name, message in self._models.items():
            payload = message.SerializeToString()
            if not payload:
                # An empty serialization is replaced by a placeholder message
                # with its flag set, so the stored payload is never empty.
                from fate_arch.protobuf.python import default_empty_fill_pb2

                message = default_empty_fill_pb2.DefaultEmptyFillMessage()
                message.flag = "set"
                payload = message.SerializeToString()
            # The type name is taken after the possible placeholder swap.
            packed[name] = (type(message).__name__, payload)
        return packed

    @property
    def cache(self):
        """Return the output cache as a list."""
        return self._cache
class ComponentBase(metaclass=abc.ABCMeta):
    """Base class for runnable components.

    Subclasses implement ``_run`` (and optionally ``_retry``) and publish
    their results through ``data_output``, ``model_output`` and
    ``cache_output``.
    """

    def __init__(self):
        # Result slots, read back by the save_*/export_* accessors below.
        self.data_output = None
        self.model_output = None
        self.cache_output = None
        # Forwarded to ComponentOutput by run().
        self.serialize = True
        # Execution context; overwritten from the input object in run().
        self.task_version_id = ""
        self.tracker: TrackerClient = None
        self.checkpoint_manager = None

    @abc.abstractmethod
    def _run(self, cpn_input: ComponentInputProtocol):
        """to be implemented"""

    def _retry(self, cpn_input: ComponentInputProtocol):
        """Resume hook used by ``run`` when a checkpoint exists.

        The default does nothing and deliberately does not raise, so
        components without retry support can still go through this path.
        """

    def run(self, cpn_input: ComponentInputProtocol, retry: bool = True):
        """Execute the component and wrap its results.

        Args:
            cpn_input: Input object; its ``task_version_id``, ``tracker`` and
                ``checkpoint_manager`` are copied onto the component first.
            retry: When true and the checkpoint manager reports a latest
                checkpoint, ``_retry`` is called instead of ``_run``.

        Returns:
            A ``ComponentOutput`` built from ``save_data()``,
            ``export_model()``, ``save_cache()`` and ``self.serialize``.
        """
        self.task_version_id = cpn_input.task_version_id
        self.tracker = cpn_input.tracker
        self.checkpoint_manager = cpn_input.checkpoint_manager

        manager = self.checkpoint_manager
        resume = (
            retry
            and callable(getattr(self, "_retry", None))
            and manager is not None
            and manager.latest_checkpoint is not None
        )
        step = self._retry if resume else self._run
        step(cpn_input=cpn_input)

        return ComponentOutput(
            data=self.save_data(),
            models=self.export_model(),
            cache=self.save_cache(),
            serialize=self.serialize,
        )

    def save_data(self):
        """Return the data output slot."""
        return self.data_output

    def export_model(self):
        """Return the model output slot."""
        return self.model_output

    def save_cache(self):
        """Return the cache output slot."""
        return self.cache_output
class _RunnerDecorator:
    """Chainable class decorator that registers a runner class per role.

    Each ``on_<role>`` property records a role and returns the decorator
    itself, so roles can be chained before the decorator is applied.
    """

    def __init__(self, meta) -> None:
        self._roles = set()
        self._meta = meta

    def _enable(self, role):
        """Record ``role`` and return ``self`` so accesses can be chained."""
        self._roles.add(role)
        return self

    @property
    def on_guest(self):
        """Add the ``guest`` role."""
        return self._enable("guest")

    @property
    def on_host(self):
        """Add the ``host`` role."""
        return self._enable("host")

    @property
    def on_arbiter(self):
        """Add the ``arbiter`` role."""
        return self._enable("arbiter")

    @property
    def on_local(self):
        """Add the ``local`` role."""
        return self._enable("local")

    def __call__(self, cls):
        """Register ``cls`` as the runner for every recorded role.

        Raises:
            NotImplementedError: If ``cls`` is not a ``ComponentBase`` subclass.
        """
        if not issubclass(cls, ComponentBase):
            raise NotImplementedError(f"type of {cls} not supported")
        for role in self._roles:
            self._meta._role_to_runner_cls[role] = cls
        return cls
class ComponentMeta:
    """Registry entry tying a component name to its runner and param classes."""

    # Every instance registers itself here under its name.
    __name_to_obj: typing.Dict[str, "ComponentMeta"] = {}

    def __init__(self, name) -> None:
        """Create the meta object and register it under ``name``."""
        self.name = name
        self._param_cls = None
        self._role_to_runner_cls = {}
        self.__name_to_obj[name] = self

    @property
    def bind_runner(self):
        """Return a fresh ``_RunnerDecorator`` bound to this meta object."""
        return _RunnerDecorator(self)

    @property
    def bind_param(self):
        """Return a class decorator that records the decorated class as the param class."""

        def _remember(param_cls):
            self._param_cls = param_cls
            return param_cls

        return _remember

    def register_info(self):
        """Return ``{name: {"module": ...}}`` using this object's ``__module__``."""
        return {self.name: {"module": self.__module__}}

    @classmethod
    def get_meta(cls, name):
        """Return the meta object registered under ``name`` (``KeyError`` if absent)."""
        return cls.__name_to_obj[name]

    def _get_runner(self, role: str):
        """Return the runner class bound to ``role``.

        Raises:
            ModuleNotFoundError: If no runner is registered for ``role``.
        """
        runners = self._role_to_runner_cls
        if role in runners:
            return runners[role]
        raise ModuleNotFoundError(
            f"Runner for component `{self.name}` at role `{role}` not found"
        )

    def get_run_obj(self, role: str):
        """Instantiate the runner class bound to ``role``."""
        runner_cls = self._get_runner(role)
        return runner_cls()

    def get_run_obj_name(self, role: str) -> str:
        """Return the class name of the runner bound to ``role``."""
        runner_cls = self._get_runner(role)
        return runner_cls.__name__

    def get_param_obj(self, cpn_name: str):
        """Instantiate the param class and name it ``<meta name>#<cpn_name>``.

        Raises:
            ModuleNotFoundError: If no param class has been bound.
        """
        if self._param_cls is None:
            raise ModuleNotFoundError(f"Param for component `{self.name}` not found")
        return self._param_cls().set_name(f"{self.name}#{cpn_name}")

    def get_supported_roles(self):
        """Return the set of roles that have a runner.

        Raises:
            ModuleNotFoundError: If no role has been registered.
        """
        supported = set(self._role_to_runner_cls)
        if supported:
            return supported
        raise ModuleNotFoundError(f"roles for {self.name} is empty")
class BaseParam(object):
    """Base class for component parameter objects."""

    def set_name(self, name: str):
        """Store ``name`` on the instance and return ``self`` for chaining."""
        self._name = name
        return self

    def check(self):
        """Validate the parameters; subclasses are expected to override this.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Parameter Object should have be check")

    def as_dict(self):
        """Return this object converted by ``ParamExtract.change_param_to_dict``."""
        return ParamExtract().change_param_to_dict(self)

    @classmethod
    def from_dict(cls, conf):
        """Build a default instance of ``cls`` and update it from ``conf``."""
        obj = cls()
        obj.update(conf)
        return obj

    def update(self, conf, allow_redundant=False):
        """Apply ``conf`` via ``ParamExtract.recursive_parse_param_from_config``.

        Args:
            conf: Configuration to parse into this object.
            allow_redundant: When true, ``valid_check`` is passed as ``False``
                to the parser.

        Returns:
            Whatever ``recursive_parse_param_from_config`` returns.
        """
        # ``_name`` only exists after ``set_name``; ``from_dict`` updates a
        # fresh instance, so fall back to the class name instead of raising
        # AttributeError.
        name = getattr(self, "_name", type(self).__name__)
        return ParamExtract().recursive_parse_param_from_config(
            param=self,
            config_json=conf,
            param_parse_depth=0,
            valid_check=not allow_redundant,
            name=name,
        )
|