#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import argparse
import json
import os
import tarfile
import traceback
from contextlib import closing
import time
import re

import requests
from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor

# be sure to import environment variables before importing fate_arch
from fate_flow import set_env
from fate_flow.settings import API_VERSION, HOST, HTTP_PORT
from fate_flow.utils import detect_utils, requests_utils
from fate_flow.utils.base_utils import get_fate_flow_directory
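
# Function names accepted by -f/--function, grouped by the server endpoint
# family that call_fun() routes them to.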
JOB_OPERATE_FUNC = ["submit_job", "stop_job", "query_job", "data_view_query", "clean_job", "clean_queue"]
JOB_FUNC = ["job_config", "job_log_download"]
TASK_OPERATE_FUNC = ["query_task"]
TRACKING_FUNC = ["component_parameters", "component_metric_all", "component_metric_delete", "component_metrics",
                 "component_output_model", "component_output_data", "component_output_data_table"]
DATA_FUNC = ["download", "upload", "upload_history"]
TABLE_FUNC = ["table_info", "table_delete", "table_add", "table_bind"]
MODEL_FUNC = ["load", "bind", "store", "restore", "export", "import"]
PERMISSION_FUNC = ["grant_privilege", "delete_privilege", "query_privilege"]
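

# Print the response as indented JSON when verbose, then return it unchanged.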
def prettify(response, verbose=True):
    if verbose:
        print(json.dumps(response, indent=4, ensure_ascii=False))
        print()
    return response
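

# Dispatch a CLI function to the matching fate flow HTTP endpoint and return
# the decoded response.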
def call_fun(func, config_data, dsl_path, config_path):
    server_url = "http://{}:{}/{}".format(HOST, HTTP_PORT, API_VERSION)
    response = None
    if func in JOB_OPERATE_FUNC:
        if func == 'submit_job':
            if not config_path:
                raise Exception('the following arguments are required: {}'.format('runtime conf path'))
            # a train job (any job_type other than 'predict') must supply a dsl file
            if not dsl_path and config_data.get('job_parameters', {}).get('job_type', '') != 'predict':
                raise Exception('for train job, the following arguments are required: {}'.format('dsl path'))
            dsl_data = {}
            if dsl_path:
                dsl_path = os.path.abspath(dsl_path)
                with open(dsl_path, 'r') as f:
                    dsl_data = json.load(f)
            post_data = {'job_dsl': dsl_data,
                         'job_runtime_conf': config_data}
            # slice off the '_job' suffix; str.rstrip('_job') would strip any
            # trailing '_', 'j', 'o' or 'b' characters, not the literal suffix
            response = requests_utils.request(method="post", url="/".join([server_url, "job", func[:-len('_job')]]), json=post_data)
            try:
                if response.json()['retcode'] == 999:
                    start_cluster_standalone_job_server()
                    response = requests_utils.request(method="post", url="/".join([server_url, "job", func[:-len('_job')]]), json=post_data)
            except Exception:
                pass
        elif func == 'data_view_query' or func == 'clean_queue':
            response = requests_utils.request(method="post", url="/".join([server_url, "job", func.replace('_', '/')]), json=config_data)
        else:
            if func != 'query_job':
                detect_utils.check_config(config=config_data, required_arguments=['job_id'])
            post_data = config_data
            response = requests_utils.request(method="post", url="/".join([server_url, "job", func[:-len('_job')]]), json=post_data)
            if func == 'query_job':
                response = response.json()
                if response['retcode'] == 0:
                    # strip the bulky conf and dsl fields from each job record
                    for i in range(len(response['data'])):
                        del response['data'][i]['f_runtime_conf']
                        del response['data'][i]['f_dsl']
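    # job_config / job_log_download fetch job artifacts from the server and
    # write them under the caller-supplied output_path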
    elif func in JOB_FUNC:
        if func == 'job_config':
            detect_utils.check_config(config=config_data, required_arguments=['job_id', 'role', 'party_id', 'output_path'])
            response = requests_utils.request(method="post", url="/".join([server_url, func.replace('_', '/')]), json=config_data)
            response_data = response.json()
            if response_data['retcode'] == 0:
                job_id = response_data['data']['job_id']
                download_directory = os.path.join(config_data['output_path'], 'job_{}_config'.format(job_id))
                os.makedirs(download_directory, exist_ok=True)
                for k, v in response_data['data'].items():
                    if k == 'job_id':
                        continue
                    with open('{}/{}.json'.format(download_directory, k), 'w') as fw:
                        json.dump(v, fw, indent=4)
                # delete after the loop: removing keys while iterating the
                # dict would raise a RuntimeError
                del response_data['data']['dsl']
                del response_data['data']['runtime_conf']
                response_data['directory'] = download_directory
                response_data['retmsg'] = 'downloaded successfully, please check the {} directory'.format(download_directory)
            response = response_data
        elif func == 'job_log_download':
            detect_utils.check_config(config=config_data, required_arguments=['job_id', 'output_path'])
            job_id = config_data['job_id']
            tar_file_name = 'job_{}_log.tar.gz'.format(job_id)
            extract_dir = os.path.join(config_data['output_path'], 'job_{}_log'.format(job_id))
            with closing(requests_utils.request(method="post", url="/".join([server_url, func.replace('_', '/')]), json=config_data,
                                                stream=True)) as response:
                if response.status_code == 200:
                    download_from_request(http_response=response, tar_file_name=tar_file_name, extract_dir=extract_dir)
                    response = {'retcode': 0,
                                'directory': extract_dir,
                                'retmsg': 'downloaded successfully, please check the {} directory'.format(extract_dir)}
                else:
                    response = response.json()
    elif func in TASK_OPERATE_FUNC:
        response = requests_utils.request(method="post", url="/".join([server_url, "job", "task", func[:-len('_task')]]), json=config_data)
    elif func in TRACKING_FUNC:
        if func != 'component_metric_delete':
            detect_utils.check_config(config=config_data,
                                      required_arguments=['job_id', 'component_name', 'role', 'party_id'])
        if func == 'component_output_data':
            detect_utils.check_config(config=config_data, required_arguments=['output_path'])
            tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(config_data['job_id'],
                                                                        config_data['component_name'],
                                                                        config_data['role'],
                                                                        config_data['party_id'])
            extract_dir = os.path.join(config_data['output_path'], tar_file_name.replace('.tar.gz', ''))
            with closing(requests_utils.request(method="get", url="/".join([server_url, "tracking", func.replace('_', '/'), 'download']),
                                                json=config_data, stream=True)) as response:
                if response.status_code == 200:
                    try:
                        download_from_request(http_response=response, tar_file_name=tar_file_name, extract_dir=extract_dir)
                        response = {'retcode': 0,
                                    'directory': extract_dir,
                                    'retmsg': 'downloaded successfully, please check the {} directory'.format(extract_dir)}
                    except Exception:
                        response = {'retcode': 100,
                                    'retmsg': 'download failed, please check whether the parameters are correct'}
                else:
                    response = response.json()
        else:
            response = requests_utils.request(method="post", url="/".join([server_url, "tracking", func.replace('_', '/')]), json=config_data)
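    # 'upload' streams the local file as multipart form data so large datasets
    # are not read into memory at once; the other data functions are plain posts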
    elif func in DATA_FUNC:
        if func == 'upload' and config_data.get('use_local_data', 1) != 0:
            file_name = config_data.get('file')
            if not os.path.isabs(file_name):
                file_name = os.path.join(get_fate_flow_directory(), file_name)
            if os.path.exists(file_name):
                with open(file_name, 'rb') as fp:
                    data = MultipartEncoder(
                        fields={'file': (os.path.basename(file_name), fp, 'application/octet-stream')}
                    )
                    tag = [0]

                    def read_callback(monitor):
                        if config_data.get('verbose') == 1:
                            # the bar length uses integer division; the percentage
                            # needs true division or it would always show .00
                            sys.stdout.write("\r UPLOADING:{0}{1}".format("|" * (monitor.bytes_read * 100 // monitor.len), '%.2f%%' % (monitor.bytes_read * 100 / monitor.len)))
                            sys.stdout.flush()
                            if monitor.bytes_read / monitor.len == 1:
                                tag[0] += 1
                                if tag[0] == 2:
                                    sys.stdout.write('\n')

                    data = MultipartEncoderMonitor(data, read_callback)
                    response = requests_utils.request(method="post", url="/".join([server_url, "data", func.replace('_', '/')]), data=data,
                                                      params=json.dumps(config_data), headers={'Content-Type': data.content_type})
            else:
                raise Exception('The file should be on the fate flow client machine, but it does not exist; '
                                'please check the path: {}'.format(file_name))
        else:
            response = requests_utils.request(method="post", url="/".join([server_url, "data", func.replace('_', '/')]), json=config_data)
        try:
            if response.json()['retcode'] == 999:
                start_cluster_standalone_job_server()
                # retry against the same endpoint as above ('upload_history'
                # contains an underscore, so func alone would be a wrong path)
                response = requests_utils.request(method="post", url="/".join([server_url, "data", func.replace('_', '/')]), json=config_data)
        except Exception:
            pass
    elif func in TABLE_FUNC:
        if func == "table_info":
            detect_utils.check_config(config=config_data, required_arguments=['namespace', 'table_name'])
            response = requests_utils.request(method="post", url="/".join([server_url, "table", func]), json=config_data)
        else:
            response = requests_utils.request(method="post", url="/".join([server_url, func.replace('_', '/')]), json=config_data)
    elif func in MODEL_FUNC:
        if func == "import":
            file_path = config_data["file"]
            if not os.path.isabs(file_path):
                file_path = os.path.join(get_fate_flow_directory(), file_path)
            if not os.path.exists(file_path):
                raise Exception('The file should be on the fate flow client machine, but it does not exist; '
                                'please check the path: {}'.format(file_path))
            # use a context manager so the archive handle is always closed
            with open(file_path, 'rb') as fp:
                response = requests_utils.request(method="post", url="/".join([server_url, "model", func]), data=config_data, files={'file': fp})
        elif func == "export":
            with closing(requests_utils.request(method="get", url="/".join([server_url, "model", func]), json=config_data, stream=True)) as response:
                if response.status_code == 200:
                    archive_file_name = re.findall("filename=(.+)", response.headers["Content-Disposition"])[0]
                    os.makedirs(config_data["output_path"], exist_ok=True)
                    archive_file_path = os.path.join(config_data["output_path"], archive_file_name)
                    with open(archive_file_path, 'wb') as fw:
                        for chunk in response.iter_content(1024):
                            if chunk:
                                fw.write(chunk)
                    response = {'retcode': 0,
                                'file': archive_file_path,
                                'retmsg': 'downloaded successfully, please check {}'.format(archive_file_path)}
                else:
                    response = response.json()
        else:
            response = requests_utils.request(method="post", url="/".join([server_url, "model", func]), json=config_data)
    elif func in PERMISSION_FUNC:
        detect_utils.check_config(config=config_data, required_arguments=['src_party_id', 'src_role'])
        response = requests_utils.request(method="post", url="/".join([server_url, "permission", func.replace('_', '/')]), json=config_data)
    return response.json() if isinstance(response, requests.models.Response) else response
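

# Stream a tar.gz HTTP response to disk, extract it into extract_dir,
# then remove the temporary archive.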
def download_from_request(http_response, tar_file_name, extract_dir):
    with open(tar_file_name, 'wb') as fw:
        for chunk in http_response.iter_content(1024):
            if chunk:
                fw.write(chunk)
    with tarfile.open(tar_file_name, "r:gz") as tar:
        for file_name in tar.getnames():
            tar.extract(file_name, extract_dir)
    os.remove(tar_file_name)
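

# Callers fall back to this helper when the server answers retcode 999,
# which indicates the standalone job server is not running yet.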
def start_cluster_standalone_job_server():
    print('use service.sh to start standalone node server...')
    os.system('sh service.sh start --standalone_node')
    time.sleep(5)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', required=False, type=str, help="runtime conf path")
    parser.add_argument('-d', '--dsl', required=False, type=str, help="dsl path")
    parser.add_argument('-f', '--function', type=str,
                        choices=(
                            DATA_FUNC + MODEL_FUNC + JOB_FUNC + JOB_OPERATE_FUNC + TASK_OPERATE_FUNC + TABLE_FUNC +
                            TRACKING_FUNC + PERMISSION_FUNC),
                        required=True,
                        help="function to call")
    parser.add_argument('-j', '--job_id', required=False, type=str, help="job id")
    parser.add_argument('-p', '--party_id', required=False, type=str, help="party id")
    parser.add_argument('-r', '--role', required=False, type=str, help="role")
    parser.add_argument('-cpn', '--component_name', required=False, type=str, help="component name")
    parser.add_argument('-s', '--status', required=False, type=str, help="status")
    parser.add_argument('-n', '--namespace', required=False, type=str, help="namespace")
    parser.add_argument('-t', '--table_name', required=False, type=str, help="table name")
    parser.add_argument('-w', '--work_mode', required=False, type=int, help="work mode")
    parser.add_argument('-i', '--file', required=False, type=str, help="file")
    parser.add_argument('-o', '--output_path', required=False, type=str, help="output path")
    parser.add_argument('-m', '--model', required=False, type=str, help="TrackingMetric model id")
    parser.add_argument('-drop', '--drop', required=False, type=str, help="drop data table")
    parser.add_argument('-limit', '--limit', required=False, type=int, help="limit number")
    parser.add_argument('-verbose', '--verbose', required=False, type=int, help="print upload progress: 0 or 1")
    parser.add_argument('-src_party_id', '--src_party_id', required=False, type=str, help="src party id")
    parser.add_argument('-src_role', '--src_role', required=False, type=str, help="src role")
    parser.add_argument('-privilege_role', '--privilege_role', required=False, type=str, help="privilege role")
    parser.add_argument('-privilege_command', '--privilege_command', required=False, type=str, help="privilege command")
    parser.add_argument('-privilege_component', '--privilege_component', required=False, type=str, help="privilege component")
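    # Example invocation (script and file names here are placeholders):
    #   python fate_flow_client.py -f submit_job -c examples/runtime_conf.json -d examples/dsl.json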
    try:
        args = parser.parse_args()
        config_data = {}
        dsl_path = args.dsl
        config_path = args.config
        if args.config:
            args.config = os.path.abspath(args.config)
            with open(args.config, 'r') as f:
                config_data = json.load(f)
        # command line options override values loaded from the conf file
        config_data.update(dict((k, v) for k, v in vars(args).items() if v is not None))
        if args.party_id or args.role:
            config_data['local'] = config_data.get('local', {})
            if args.party_id:
                config_data['local']['party_id'] = args.party_id
            if args.role:
                config_data['local']['role'] = args.role
        if config_data.get('output_path'):
            config_data['output_path'] = os.path.abspath(config_data["output_path"])
        response = call_fun(args.function, config_data, dsl_path, config_path)
    except Exception as e:
        exc_type, exc_value, exc_traceback_obj = sys.exc_info()
        response = {'retcode': 100, 'retmsg': str(e), 'traceback': traceback.format_exception(exc_type, exc_value, exc_traceback_obj)}
        if 'Connection refused' in str(e):
            response['retmsg'] = 'Connection refused; please check whether the fate flow service is started'
            del response['traceback']
    response_dict = prettify(response)
|