fate_flow_client.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import sys
  17. import argparse
  18. import json
  19. import os
  20. import tarfile
  21. import traceback
  22. from contextlib import closing
  23. import time
  24. import re
  25. import requests
  26. from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
  27. # be sure to import environment variable before importing fate_arch
  28. from fate_flow import set_env
  29. from fate_arch.common import file_utils
  30. from fate_flow.settings import API_VERSION, HOST, HTTP_PORT
  31. from fate_flow.utils import detect_utils, requests_utils
  32. from fate_flow.utils.base_utils import get_fate_flow_directory
# CLI "function" names accepted by -f/--function, grouped by the REST
# namespace call_fun() routes them to.
# Job lifecycle operations (POST /job/...).
JOB_OPERATE_FUNC = ["submit_job", "stop_job", "query_job", "data_view_query", "clean_job", "clean_queue"]
# Job artifact downloads (config files, log archive).
JOB_FUNC = ["job_config", "job_log_download"]
# Task-level queries (POST /job/task/...).
TASK_OPERATE_FUNC = ["query_task"]
# Component tracking endpoints (POST/GET /tracking/...).
TRACKING_FUNC = ["component_parameters", "component_metric_all", "component_metric_delete", "component_metrics",
                 "component_output_model", "component_output_data", "component_output_data_table"]
# Data transfer endpoints (POST /data/...).
DATA_FUNC = ["download", "upload", "upload_history"]
# Table management endpoints.
TABLE_FUNC = ["table_info", "table_delete", "table_add", "table_bind"]
# Model management endpoints (POST/GET /model/...).
MODEL_FUNC = ["load", "bind", "store", "restore", "export", "import"]
# Privilege management endpoints (POST /permission/...).
PERMISSION_FUNC = ["grant_privilege", "delete_privilege", "query_privilege"]
  42. def prettify(response, verbose=True):
  43. if verbose:
  44. print(json.dumps(response, indent=4, ensure_ascii=False))
  45. print()
  46. return response
def call_fun(func, config_data, dsl_path, config_path):
    """Dispatch a CLI function name to the matching FATE Flow HTTP endpoint.

    :param func: one of the names declared in the *_FUNC constants above.
    :param config_data: merged runtime configuration (conf file plus CLI overrides).
    :param dsl_path: path to the job DSL file; only used by submit_job, may be None.
    :param config_path: original runtime-conf path; only its truthiness is checked
        (submit_job requires a conf file).
    :return: a plain dict; a requests Response object is converted via .json() on exit.
    """
    server_url = "http://{}:{}/{}".format(HOST, HTTP_PORT, API_VERSION)
    response = None
    if func in JOB_OPERATE_FUNC:
        if func == 'submit_job':
            if not config_path:
                raise Exception('the following arguments are required: {}'.format('runtime conf path'))
            # DSL may be omitted only for predict jobs; train jobs need it.
            if not dsl_path and config_data.get('job_parameters', {}).get('job_type', '') == 'predict':
                raise Exception('for train job, the following arguments are required: {}'.format('dsl path'))
            dsl_data = {}
            if dsl_path:
                dsl_path = os.path.abspath(dsl_path)
                with open(dsl_path, 'r') as f:
                    dsl_data = json.load(f)
            post_data = {'job_dsl': dsl_data,
                         'job_runtime_conf': config_data}
            # NOTE(review): rstrip('_job') strips a CHARACTER SET, not a suffix;
            # it happens to yield "submit"/"stop"/"query"/"clean" for the current
            # names but would break for names ending in j/o/b/_ differently.
            response = requests_utils.request(method="post", url="/".join([server_url, "job", func.rstrip('_job')]), json=post_data)
            try:
                # retcode 999: service not running — best-effort attempt to start
                # the standalone server, then retry once.
                if response.json()['retcode'] == 999:
                    start_cluster_standalone_job_server()
                    response = requests_utils.request(method="post", url="/".join([server_url, "job", func.rstrip('_job')]), json=post_data)
            except:
                # NOTE(review): bare except deliberately swallows retry failures
                # so the original response is still reported.
                pass
        elif func == 'data_view_query' or func == 'clean_queue':
            # e.g. data_view_query -> POST /job/data/view/query
            response = requests_utils.request(method="post", url="/".join([server_url, "job", func.replace('_', '/')]), json=config_data)
        else:
            # stop_job / clean_job require a job_id; query_job accepts filters.
            if func != 'query_job':
                detect_utils.check_config(config=config_data, required_arguments=['job_id'])
            post_data = config_data
            response = requests_utils.request(method="post", url="/".join([server_url, "job", func.rstrip('_job')]), json=post_data)
            if func == 'query_job':
                response = response.json()
                if response['retcode'] == 0:
                    # Trim the bulky conf/DSL payloads from each job record before printing.
                    for i in range(len(response['data'])):
                        del response['data'][i]['f_runtime_conf']
                        del response['data'][i]['f_dsl']
    elif func in JOB_FUNC:
        if func == 'job_config':
            detect_utils.check_config(config=config_data, required_arguments=['job_id', 'role', 'party_id', 'output_path'])
            response = requests_utils.request(method="post", url="/".join([server_url, func.replace('_', '/')]), json=config_data)
            response_data = response.json()
            if response_data['retcode'] == 0:
                job_id = response_data['data']['job_id']
                download_directory = os.path.join(config_data['output_path'], 'job_{}_config'.format(job_id))
                os.makedirs(download_directory, exist_ok=True)
                # Write each returned config section to its own <key>.json file.
                for k, v in response_data['data'].items():
                    if k == 'job_id':
                        continue
                    with open('{}/{}.json'.format(download_directory, k), 'w') as fw:
                        json.dump(v, fw, indent=4)
                # The files now hold the data; drop the large blobs from the reply.
                del response_data['data']['dsl']
                del response_data['data']['runtime_conf']
                response_data['directory'] = download_directory
                response_data['retmsg'] = 'download successfully, please check {} directory'.format(download_directory)
                response = response_data
        elif func == 'job_log_download':
            detect_utils.check_config(config=config_data, required_arguments=['job_id', 'output_path'])
            job_id = config_data['job_id']
            tar_file_name = 'job_{}_log.tar.gz'.format(job_id)
            extract_dir = os.path.join(config_data['output_path'], 'job_{}_log'.format(job_id))
            # Stream the log archive and unpack it next to output_path.
            with closing(requests_utils.request(method="post", url="/".join([server_url, func.replace('_', '/')]), json=config_data,
                                                stream=True)) as response:
                if response.status_code == 200:
                    download_from_request(http_response=response, tar_file_name=tar_file_name, extract_dir=extract_dir)
                    response = {'retcode': 0,
                                'directory': extract_dir,
                                'retmsg': 'download successfully, please check {} directory'.format(extract_dir)}
                else:
                    response = response.json()
    elif func in TASK_OPERATE_FUNC:
        # NOTE(review): rstrip('_task') — same character-set caveat as rstrip('_job').
        response = requests_utils.request(method="post", url="/".join([server_url, "job", "task", func.rstrip('_task')]), json=config_data)
    elif func in TRACKING_FUNC:
        if func != 'component_metric_delete':
            detect_utils.check_config(config=config_data,
                                      required_arguments=['job_id', 'component_name', 'role', 'party_id'])
        if func == 'component_output_data':
            # Output data arrives as a tar.gz; download and extract it locally.
            detect_utils.check_config(config=config_data, required_arguments=['output_path'])
            tar_file_name = 'job_{}_{}_{}_{}_output_data.tar.gz'.format(config_data['job_id'],
                                                                        config_data['component_name'],
                                                                        config_data['role'],
                                                                        config_data['party_id'])
            extract_dir = os.path.join(config_data['output_path'], tar_file_name.replace('.tar.gz', ''))
            with closing(requests_utils.request(method="get", url="/".join([server_url, "tracking", func.replace('_', '/'), 'download']),
                                                json=config_data, stream=True)) as response:
                if response.status_code == 200:
                    try:
                        download_from_request(http_response=response, tar_file_name=tar_file_name, extract_dir=extract_dir)
                        response = {'retcode': 0,
                                    'directory': extract_dir,
                                    'retmsg': 'download successfully, please check {} directory'.format(extract_dir)}
                    except:
                        # NOTE(review): bare except hides the real extraction error.
                        response = {'retcode': 100,
                                    'retmsg': 'download failed, please check if the parameters are correct'}
                else:
                    response = response.json()
        else:
            response = requests_utils.request(method="post", url="/".join([server_url, "tracking", func.replace('_', '/')]), json=config_data)
    elif func in DATA_FUNC:
        # use_local_data defaults to 1: read the file from this client machine.
        if func == 'upload' and config_data.get('use_local_data', 1) != 0:
            file_name = config_data.get('file')
            if not os.path.isabs(file_name):
                # Relative paths are resolved against the FATE Flow directory.
                file_name = os.path.join(get_fate_flow_directory(), file_name)
            if os.path.exists(file_name):
                with open(file_name, 'rb') as fp:
                    # Stream the file as multipart/form-data instead of loading it whole.
                    data = MultipartEncoder(
                        fields={'file': (os.path.basename(file_name), fp, 'application/octet-stream')}
                    )
                    # Mutable cell so the nested callback can count completions.
                    tag = [0]

                    def read_callback(monitor):
                        # Draw a textual progress bar when -verbose 1 was passed.
                        if config_data.get('verbose') == 1:
                            sys.stdout.write("\r UPLOADING:{0}{1}".format("|" * (monitor.bytes_read * 100 // monitor.len), '%.2f%%' % (monitor.bytes_read * 100 // monitor.len)))
                            sys.stdout.flush()
                            if monitor.bytes_read / monitor.len == 1:
                                # The encoder reports 100% twice; newline on the second.
                                tag[0] += 1
                                if tag[0] == 2:
                                    sys.stdout.write('\n')
                    data = MultipartEncoderMonitor(data, read_callback)
                    response = requests_utils.request(method="post", url="/".join([server_url, "data", func.replace('_', '/')]), data=data,
                                                      params=json.dumps(config_data), headers={'Content-Type': data.content_type})
            else:
                raise Exception('The file is obtained from the fate flow client machine, but it does not exist, '
                                'please check the path: {}'.format(file_name))
        else:
            response = requests_utils.request(method="post", url="/".join([server_url, "data", func.replace('_', '/')]), json=config_data)
            try:
                # Same retcode-999 retry as submit_job above.
                if response.json()['retcode'] == 999:
                    start_cluster_standalone_job_server()
                    # NOTE(review): retry uses the raw func (no '_'->'/' rewrite),
                    # unlike the first attempt — presumably fine for single-word
                    # names like "download"; verify for upload_history.
                    response = requests_utils.request(method="post", url="/".join([server_url, "data", func]), json=config_data)
            except:
                pass
    elif func in TABLE_FUNC:
        if func == "table_info":
            detect_utils.check_config(config=config_data, required_arguments=['namespace', 'table_name'])
            response = requests_utils.request(method="post", url="/".join([server_url, "table", func]), json=config_data)
        else:
            # e.g. table_delete -> POST /table/delete
            response = requests_utils.request(method="post", url="/".join([server_url, func.replace('_', '/')]), json=config_data)
    elif func in MODEL_FUNC:
        if func == "import":
            file_path = config_data["file"]
            if not os.path.isabs(file_path):
                file_path = os.path.join(get_fate_flow_directory(), file_path)
            if os.path.exists(file_path):
                # NOTE(review): file handle is never explicitly closed; it is
                # released only when garbage-collected after the request.
                files = {'file': open(file_path, 'rb')}
            else:
                raise Exception('The file is obtained from the fate flow client machine, but it does not exist, '
                                'please check the path: {}'.format(file_path))
            response = requests_utils.request(method="post", url="/".join([server_url, "model", func]), data=config_data, files=files)
        elif func == "export":
            # Stream the exported model archive to output_path; the server
            # supplies the file name via Content-Disposition.
            with closing(requests_utils.request(method="get", url="/".join([server_url, "model", func]), json=config_data, stream=True)) as response:
                if response.status_code == 200:
                    archive_file_name = re.findall("filename=(.+)", response.headers["Content-Disposition"])[0]
                    os.makedirs(config_data["output_path"], exist_ok=True)
                    archive_file_path = os.path.join(config_data["output_path"], archive_file_name)
                    with open(archive_file_path, 'wb') as fw:
                        for chunk in response.iter_content(1024):
                            if chunk:
                                fw.write(chunk)
                    response = {'retcode': 0,
                                'file': archive_file_path,
                                'retmsg': 'download successfully, please check {}'.format(archive_file_path)}
                else:
                    response = response.json()
        else:
            response = requests_utils.request(method="post", url="/".join([server_url, "model", func]), json=config_data)
    elif func in PERMISSION_FUNC:
        detect_utils.check_config(config=config_data, required_arguments=['src_party_id', 'src_role'])
        response = requests_utils.request(method="post", url="/".join([server_url, "permission", func.replace('_', '/')]), json=config_data)
    # Normalize: callers always receive a plain dict.
    return response.json() if isinstance(response, requests.models.Response) else response
  215. def download_from_request(http_response, tar_file_name, extract_dir):
  216. with open(tar_file_name, 'wb') as fw:
  217. for chunk in http_response.iter_content(1024):
  218. if chunk:
  219. fw.write(chunk)
  220. tar = tarfile.open(tar_file_name, "r:gz")
  221. file_names = tar.getnames()
  222. for file_name in file_names:
  223. tar.extract(file_name, extract_dir)
  224. tar.close()
  225. os.remove(tar_file_name)
def start_cluster_standalone_job_server():
    """Launch the standalone FATE Flow node via service.sh and wait for it to come up."""
    # Fallback invoked when a request returns retcode 999 (service not running).
    print('use service.sh to start standalone node server....')
    # NOTE(review): fixed command string, no untrusted input; relies on
    # service.sh being reachable from the current working directory — TODO confirm.
    os.system('sh service.sh start --standalone_node')
    # Crude readiness wait; assumes the server is accepting requests within 5s.
    time.sleep(5)
  230. if __name__ == "__main__":
  231. parser = argparse.ArgumentParser()
  232. parser.add_argument('-c', '--config', required=False, type=str, help="runtime conf path")
  233. parser.add_argument('-d', '--dsl', required=False, type=str, help="dsl path")
  234. parser.add_argument('-f', '--function', type=str,
  235. choices=(
  236. DATA_FUNC + MODEL_FUNC + JOB_FUNC + JOB_OPERATE_FUNC + TASK_OPERATE_FUNC + TABLE_FUNC +
  237. TRACKING_FUNC + PERMISSION_FUNC),
  238. required=True,
  239. help="function to call")
  240. parser.add_argument('-j', '--job_id', required=False, type=str, help="job id")
  241. parser.add_argument('-p', '--party_id', required=False, type=str, help="party id")
  242. parser.add_argument('-r', '--role', required=False, type=str, help="role")
  243. parser.add_argument('-cpn', '--component_name', required=False, type=str, help="component name")
  244. parser.add_argument('-s', '--status', required=False, type=str, help="status")
  245. parser.add_argument('-n', '--namespace', required=False, type=str, help="namespace")
  246. parser.add_argument('-t', '--table_name', required=False, type=str, help="table name")
  247. parser.add_argument('-w', '--work_mode', required=False, type=int, help="work mode")
  248. parser.add_argument('-i', '--file', required=False, type=str, help="file")
  249. parser.add_argument('-o', '--output_path', required=False, type=str, help="output_path")
  250. parser.add_argument('-m', '--model', required=False, type=str, help="TrackingMetric model id")
  251. parser.add_argument('-drop', '--drop', required=False, type=str, help="drop data table")
  252. parser.add_argument('-limit', '--limit', required=False, type=int, help="limit number")
  253. parser.add_argument('-verbose', '--verbose', required=False, type=int, help="number 0 or 1")
  254. parser.add_argument('-src_party_id', '--src_party_id', required=False, type=str, help="src party id")
  255. parser.add_argument('-src_role', '--src_role', required=False, type=str, help="src role")
  256. parser.add_argument('-privilege_role', '--privilege_role', required=False, type=str, help="privilege role")
  257. parser.add_argument('-privilege_command', '--privilege_command', required=False, type=str, help="privilege command")
  258. parser.add_argument('-privilege_component', '--privilege_component', required=False, type=str, help="privilege component")
  259. try:
  260. args = parser.parse_args()
  261. config_data = {}
  262. dsl_path = args.dsl
  263. config_path = args.config
  264. if args.config:
  265. args.config = os.path.abspath(args.config)
  266. with open(args.config, 'r') as f:
  267. config_data = json.load(f)
  268. config_data.update(dict((k, v) for k, v in vars(args).items() if v is not None))
  269. if args.party_id or args.role:
  270. config_data['local'] = config_data.get('local', {})
  271. if args.party_id:
  272. config_data['local']['party_id'] = args.party_id
  273. if args.role:
  274. config_data['local']['role'] = args.role
  275. if config_data.get('output_path'):
  276. config_data['output_path'] = os.path.abspath(config_data["output_path"])
  277. response = call_fun(args.function, config_data, dsl_path, config_path)
  278. except Exception as e:
  279. exc_type, exc_value, exc_traceback_obj = sys.exc_info()
  280. response = {'retcode': 100, 'retmsg': str(e), 'traceback': traceback.format_exception(exc_type, exc_value, exc_traceback_obj)}
  281. if 'Connection refused' in str(e):
  282. response['retmsg'] = 'Connection refused, Please check if the fate flow service is started'
  283. del response['traceback']
  284. response_dict = prettify(response)