model.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import os
  17. import re
  18. from contextlib import closing
  19. from flow_sdk.client.api.base import BaseFlowAPI
  20. from flow_sdk.utils import preprocess, get_project_base_directory
  21. class Model(BaseFlowAPI):
  22. def load(self, config_data=None, job_id=None):
  23. if config_data is None and job_id is None:
  24. return {
  25. "retcode": 100,
  26. "retmsg": "Load model failed. No arguments received, "
  27. "please provide one of arguments from job id and conf path."
  28. }
  29. if config_data is not None and job_id is not None:
  30. return {
  31. "retcode": 100,
  32. "retmsg": "Load model failed. Please do not provide job id and "
  33. "conf path at the same time."
  34. }
  35. kwargs = locals()
  36. config_data, dsl_data = preprocess(**kwargs)
  37. return self._post(url='model/load', json=config_data)
  38. def bind(self, config_data, job_id=None):
  39. kwargs = locals()
  40. config_data, dsl_data = preprocess(**kwargs)
  41. return self._post(url='model/bind', json=config_data)
  42. def import_model(self, config_data, from_database=False):
  43. kwargs = locals()
  44. config_data, dsl_data = preprocess(**kwargs)
  45. if kwargs.pop('from_database'):
  46. return self._post(url='model/restore', json=config_data)
  47. file_path = config_data['file']
  48. if not os.path.isabs(file_path):
  49. file_path = os.path.join(get_project_base_directory(), file_path)
  50. if os.path.exists(file_path):
  51. FileNotFoundError(
  52. 'The file is obtained from the fate flow client machine, but it does not exist, '
  53. ' please check the path: {}'.format(file_path)
  54. )
  55. config_data['force_update'] = int(config_data.get('force_update', False))
  56. files = {'file': open(file_path, 'rb')}
  57. return self._post(url='model/import', data=config_data, files=files)
  58. def export_model(self, config_data, to_database=False):
  59. kwargs = locals()
  60. config_data, dsl_data = preprocess(**kwargs)
  61. if not config_data.pop("to_database"):
  62. with closing(self._get(url='model/export', handle_result=False, json=config_data, stream=True)) as response:
  63. if response.status_code == 200:
  64. archive_file_name = re.findall("filename=(.+)", response.headers["Content-Disposition"])[0]
  65. os.makedirs(config_data["output_path"], exist_ok=True)
  66. archive_file_path = os.path.join(config_data["output_path"], archive_file_name)
  67. with open(archive_file_path, 'wb') as fw:
  68. for chunk in response.iter_content(1024):
  69. if chunk:
  70. fw.write(chunk)
  71. response = {'retcode': 0,
  72. 'file': archive_file_path,
  73. 'retmsg': 'download successfully, please check {}'.format(archive_file_path)}
  74. else:
  75. response = response.json()
  76. return response
  77. return self._post(url='model/store', json=config_data)
  78. def migrate(self, config_data):
  79. kwargs = locals()
  80. config_data, dsl_data = preprocess(**kwargs)
  81. return self._post(url='model/migrate', json=config_data)
  82. def tag_model(self, job_id, tag_name, remove=False):
  83. kwargs = locals()
  84. config_data, dsl_data = preprocess(**kwargs)
  85. if not config_data.pop('remove'):
  86. return self._post(url='model/model_tag/create', json=config_data)
  87. else:
  88. return self._post(url='model/model_tag/remove', json=config_data)
  89. def tag_list(self, job_id):
  90. kwargs = locals()
  91. config_data, dsl_data = preprocess(**kwargs)
  92. return self._post(url='model/model_tag/retrieve', json=config_data)
  93. def deploy(self, model_id, model_version, cpn_list=None, predict_dsl=None, components_checkpoint=None):
  94. kwargs = locals()
  95. config_data, dsl_data = preprocess(**kwargs)
  96. return self._post(url='model/deploy', json=config_data)
  97. def get_predict_dsl(self, model_id, model_version):
  98. kwargs = locals()
  99. config_data, dsl_data = preprocess(**kwargs)
  100. return self._post(url='model/get/predict/dsl', json=config_data)
  101. def get_predict_conf(self, model_id, model_version):
  102. kwargs = locals()
  103. config_data, dsl_data = preprocess(**kwargs)
  104. return self._post(url='model/get/predict/conf', json=config_data)
  105. def get_model_info(self, model_id=None, model_version=None, role=None, party_id=None, query_filters=None, **kwargs):
  106. kwargs = locals()
  107. config_data, dsl_data = preprocess(**kwargs)
  108. return self._post(url='model/query', json=config_data)
  109. def homo_convert(self, config_data):
  110. kwargs = locals()
  111. config_data, dsl_data = preprocess(**kwargs)
  112. return self._post(url='model/homo/convert', json=config_data)
  113. def homo_deploy(self, config_data):
  114. kwargs = locals()
  115. config_data, dsl_data = preprocess(**kwargs)
  116. if config_data.get('deployment_type') == "kfserving":
  117. kube_config = config_data.get('deployment_parameters', {}).get('config_file')
  118. if kube_config:
  119. if not os.path.isabs(kube_config):
  120. kube_config = os.path.join(get_project_base_directory(), kube_config)
  121. if os.path.exists(kube_config):
  122. with open(kube_config, 'r') as fp:
  123. config_data['deployment_parameters']['config_file_content'] = fp.read()
  124. del config_data['deployment_parameters']['config_file']
  125. else:
  126. raise Exception('The kube_config file is obtained from the fate flow client machine, '
  127. 'but it does not exist, please check the path: {}'.format(kube_config))
  128. return self._post(url='model/homo/deploy', json=config_data)