#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
import re
from contextlib import closing

from flow_sdk.client.api.base import BaseFlowAPI
from flow_sdk.utils import preprocess, get_project_base_directory


class Model(BaseFlowAPI):
    def load(self, config_data=None, job_id=None):
        if config_data is None and job_id is None:
            return {
                "retcode": 100,
                "retmsg": "Load model failed. No arguments received, "
                          "please provide either a job id or a conf path."
            }
        if config_data is not None and job_id is not None:
            return {
                "retcode": 100,
                "retmsg": "Load model failed. Please do not provide a job id and "
                          "a conf path at the same time."
            }
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/load', json=config_data)

    def bind(self, config_data, job_id=None):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/bind', json=config_data)

    def import_model(self, config_data, from_database=False):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        if kwargs.pop('from_database'):
            # Restore a model previously stored to the database instead of
            # uploading a local archive file.
            return self._post(url='model/restore', json=config_data)

        file_path = config_data['file']
        if not os.path.isabs(file_path):
            file_path = os.path.join(get_project_base_directory(), file_path)
        if not os.path.exists(file_path):
            raise FileNotFoundError(
                'The file is obtained from the fate flow client machine, but it does not exist, '
                'please check the path: {}'.format(file_path)
            )
        config_data['force_update'] = int(config_data.get('force_update', False))
        files = {'file': open(file_path, 'rb')}
        return self._post(url='model/import', data=config_data, files=files)

    def export_model(self, config_data, to_database=False):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        if not config_data.pop("to_database"):
            # Stream the exported model archive to a local file.
            with closing(self._get(url='model/export', handle_result=False, json=config_data, stream=True)) as response:
                if response.status_code == 200:
                    archive_file_name = re.findall("filename=(.+)", response.headers["Content-Disposition"])[0]
                    os.makedirs(config_data["output_path"], exist_ok=True)
                    archive_file_path = os.path.join(config_data["output_path"], archive_file_name)
                    with open(archive_file_path, 'wb') as fw:
                        for chunk in response.iter_content(1024):
                            if chunk:
                                fw.write(chunk)
                    response = {'retcode': 0,
                                'file': archive_file_path,
                                'retmsg': 'downloaded successfully, please check {}'.format(archive_file_path)}
                else:
                    response = response.json()
            return response
        return self._post(url='model/store', json=config_data)

    def migrate(self, config_data):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/migrate', json=config_data)

    def tag_model(self, job_id, tag_name, remove=False):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        if not config_data.pop('remove'):
            return self._post(url='model/model_tag/create', json=config_data)
        else:
            return self._post(url='model/model_tag/remove', json=config_data)

    def tag_list(self, job_id):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/model_tag/retrieve', json=config_data)

    def deploy(self, model_id, model_version, cpn_list=None, predict_dsl=None, components_checkpoint=None):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/deploy', json=config_data)

    def get_predict_dsl(self, model_id, model_version):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/get/predict/dsl', json=config_data)

    def get_predict_conf(self, model_id, model_version):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/get/predict/conf', json=config_data)

    def get_model_info(self, model_id=None, model_version=None, role=None, party_id=None, query_filters=None, **kwargs):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/query', json=config_data)

    def homo_convert(self, config_data):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        return self._post(url='model/homo/convert', json=config_data)

    def homo_deploy(self, config_data):
        kwargs = locals()
        config_data, dsl_data = preprocess(**kwargs)
        if config_data.get('deployment_type') == "kfserving":
            # For KFServing deployments, inline the kube config content so the
            # server does not need filesystem access to the client machine.
            kube_config = config_data.get('deployment_parameters', {}).get('config_file')
            if kube_config:
                if not os.path.isabs(kube_config):
                    kube_config = os.path.join(get_project_base_directory(), kube_config)
                if os.path.exists(kube_config):
                    with open(kube_config, 'r') as fp:
                        config_data['deployment_parameters']['config_file_content'] = fp.read()
                    del config_data['deployment_parameters']['config_file']
                else:
                    raise Exception('The kube_config file is obtained from the fate flow client machine, '
                                    'but it does not exist, please check the path: {}'.format(kube_config))
        return self._post(url='model/homo/deploy', json=config_data)
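

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's API). It assumes the
# SDK's FlowClient is constructed as FlowClient(ip, port, version) and exposes
# an instance of this class as `client.model`; verify against flow_sdk.client
# before relying on it. All argument values below are hypothetical
# placeholders, and a running FATE Flow server is required for real calls.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from flow_sdk.client import FlowClient  # assumed SDK entry point

    client = FlowClient('127.0.0.1', 9380, 'v1')  # assumed constructor signature

    # Query model metadata by (model_id, model_version); both are placeholders.
    query_result = client.model.get_model_info(model_id='<model_id>',
                                               model_version='<model_version>')
    print(query_result)

    # Deploy the trained model so it can be used for prediction jobs.
    deploy_result = client.model.deploy(model_id='<model_id>',
                                        model_version='<model_version>')
    print(deploy_result)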