1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- # from pipeline.backend.config import WorkMode
- from pipeline.utils.logger import LOGGER
class OnlineCommand(object):
    """Load a pipeline's deployed model for online inference and bind it to a service.

    Every operation is delegated to the wrapped pipeline object and its
    ``_job_invoker``; this class only assembles the request configuration and
    enforces the deploy/load preconditions.
    """

    def __init__(self, pipeline_obj):
        """Store the pipeline whose model is to be loaded and bound.

        Args:
            pipeline_obj: pipeline instance. This class uses its
                ``_get_initiator_conf()``, ``_roles``,
                ``get_predict_model_info()``, ``is_deploy()``, ``is_load()``
                and ``_job_invoker`` members, and sets its ``_load`` flag.
        """
        self.pipeline_obj = pipeline_obj

    def _feed_online_conf(self):
        """Build the base request configuration used by ``load`` and ``bind``.

        Returns:
            dict: ``"initiator"`` and ``"role"`` taken from the pipeline, plus
            ``"job_parameters"`` holding the predict model's ``model_id`` and
            ``model_version``.
        """
        conf = {
            "initiator": self.pipeline_obj._get_initiator_conf(),
            "role": self.pipeline_obj._roles,
        }
        predict_model_info = self.pipeline_obj.get_predict_model_info()
        conf["job_parameters"] = {
            "model_id": predict_model_info.model_id,
            "model_version": predict_model_info.model_version,
        }
        return conf

    @LOGGER.catch(reraise=True)
    def load(self, file_path=None):
        """Ask the job invoker to load the deployed model for online inference.

        Args:
            file_path: optional value forwarded as
                ``job_parameters["file_path"]``; any falsy value (including
                the default ``None``) is sent as an empty string.

        Raises:
            ValueError: if ``pipeline_obj.is_deploy()`` is falsy.
        """
        if not self.pipeline_obj.is_deploy():
            raise ValueError("to load model for online inference, must deploy components first.")

        load_conf = self._feed_online_conf()
        load_conf["job_parameters"]["file_path"] = file_path or ""
        self.pipeline_obj._job_invoker.load_model(load_conf)
        # Mark the pipeline as loaded; bind() checks is_load(), which is
        # presumably backed by this flag (defined outside this class).
        self.pipeline_obj._load = True

    @LOGGER.catch(reraise=True)
    def bind(self, service_id, *servings):
        """Ask the job invoker to bind the loaded model to an online service.

        Args:
            service_id: value sent under the ``"service_id"`` key.
            *servings: zero or more values, sent as a list under the
                ``"servings"`` key.

        Raises:
            ValueError: if the pipeline is not both deployed and loaded.
        """
        if not self.pipeline_obj.is_deploy() or not self.pipeline_obj.is_load():
            raise ValueError("to bind model to online service, must deploy and load model first.")

        bind_conf = self._feed_online_conf()
        bind_conf["service_id"] = service_id
        bind_conf["servings"] = list(servings)
        self.pipeline_obj._job_invoker.bind_model(bind_conf)
class ModelConvert(object):
    """Request conversion of a trained homo model through the pipeline's job invoker."""

    def __init__(self, pipeline_obj):
        """Store the pipeline whose trained model is to be converted.

        Args:
            pipeline_obj: pipeline instance. This class reads its
                ``get_model_info()``, ``_initiator``, ``_train_dsl`` and
                ``_job_invoker`` members.
        """
        self.pipeline_obj = pipeline_obj

    def _feed_homo_conf(self, framework_name):
        """Assemble the configuration for a homo-model conversion request.

        Args:
            framework_name: target framework name; added under the
                ``"framework_name"`` key only when truthy.

        Returns:
            dict: the initiator's ``role`` and ``party_id`` plus the model's
            ``model_id`` and ``model_version``, and optionally
            ``framework_name``.
        """
        pipeline = self.pipeline_obj
        info = pipeline.get_model_info()
        homo_conf = {
            "role": pipeline._initiator.role,
            "party_id": pipeline._initiator.party_id,
            "model_id": info.model_id,
            "model_version": info.model_version,
        }
        # Omit the key entirely when no framework was requested.
        if not framework_name:
            return homo_conf
        homo_conf["framework_name"] = framework_name
        return homo_conf

    @LOGGER.catch(reraise=True)
    def convert(self, framework_name=None):
        """Submit the conversion request and return the invoker's result.

        Args:
            framework_name: optional target framework name forwarded in the
                request configuration.

        Returns:
            Whatever ``_job_invoker.convert_homo_model`` returns.

        Raises:
            ValueError: if the pipeline's ``_train_dsl`` is ``None``.
        """
        if self.pipeline_obj._train_dsl is None:
            raise ValueError("Before converting homo model, training should be finished!!!")

        homo_conf = self._feed_homo_conf(framework_name)
        return self.pipeline_obj._job_invoker.convert_homo_model(homo_conf)
|