# _operation.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# from pipeline.backend.config import WorkMode
from pipeline.utils.logger import LOGGER
  18. class OnlineCommand(object):
  19. def __init__(self, pipeline_obj):
  20. self.pipeline_obj = pipeline_obj
  21. """
  22. def _feed_online_conf(self):
  23. conf = {"initiator": self.pipeline_obj._get_initiator_conf(),
  24. "role": self.pipeline_obj._roles}
  25. predict_model_info = self.pipeline_obj.get_predict_model_info()
  26. train_work_mode = self.pipeline_obj.get_train_conf().get("job_parameters").get("common").get("work_mode")
  27. if train_work_mode != WorkMode.CLUSTER:
  28. raise ValueError(f"to use FATE serving online inference service, work mode must be CLUSTER.")
  29. conf["job_parameters"] = {"model_id": predict_model_info.model_id,
  30. "model_version": predict_model_info.model_version,
  31. "work_mode": WorkMode.CLUSTER}
  32. return conf
  33. """
  34. def _feed_online_conf(self):
  35. conf = {"initiator": self.pipeline_obj._get_initiator_conf(),
  36. "role": self.pipeline_obj._roles}
  37. predict_model_info = self.pipeline_obj.get_predict_model_info()
  38. conf["job_parameters"] = {"model_id": predict_model_info.model_id,
  39. "model_version": predict_model_info.model_version}
  40. return conf
  41. @LOGGER.catch(reraise=True)
  42. def load(self, file_path=None):
  43. if not self.pipeline_obj.is_deploy():
  44. raise ValueError(f"to load model for online inference, must deploy components first.")
  45. file_path = file_path if file_path else ""
  46. load_conf = self._feed_online_conf()
  47. load_conf["job_parameters"]["file_path"] = file_path
  48. self.pipeline_obj._job_invoker.load_model(load_conf)
  49. self.pipeline_obj._load = True
  50. @LOGGER.catch(reraise=True)
  51. def bind(self, service_id, *servings):
  52. if not self.pipeline_obj.is_deploy() or not self.pipeline_obj.is_load():
  53. raise ValueError(f"to bind model to online service, must deploy and load model first.")
  54. bind_conf = self._feed_online_conf()
  55. bind_conf["service_id"] = service_id
  56. bind_conf["servings"] = list(servings)
  57. self.pipeline_obj._job_invoker.bind_model(bind_conf)
  58. class ModelConvert(object):
  59. def __init__(self, pipeline_obj):
  60. self.pipeline_obj = pipeline_obj
  61. def _feed_homo_conf(self, framework_name):
  62. model_info = self.pipeline_obj.get_model_info()
  63. conf = {"role": self.pipeline_obj._initiator.role,
  64. "party_id": self.pipeline_obj._initiator.party_id,
  65. "model_id": model_info.model_id,
  66. "model_version": model_info.model_version
  67. }
  68. if framework_name:
  69. conf["framework_name"] = framework_name
  70. return conf
  71. @LOGGER.catch(reraise=True)
  72. def convert(self, framework_name=None):
  73. if self.pipeline_obj._train_dsl is None:
  74. raise ValueError("Before converting homo model, training should be finished!!!")
  75. conf = self._feed_homo_conf(framework_name)
  76. res_dict = self.pipeline_obj._job_invoker.convert_homo_model(conf)
  77. return res_dict