api_utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import random
  18. import time
  19. from functools import wraps
  20. from io import BytesIO
  21. from flask import (
  22. Response, jsonify, send_file,
  23. request as flask_request,
  24. )
  25. from werkzeug.http import HTTP_STATUS_CODES
  26. from fate_arch.common import (
  27. CoordinationCommunicationProtocol, CoordinationProxyService,
  28. FederatedMode,
  29. )
  30. from fate_arch.common.base_utils import json_dumps, json_loads
  31. from fate_arch.common.versions import get_fate_version
  32. from fate_flow.db.job_default_config import JobDefaultConfig
  33. from fate_flow.db.runtime_config import RuntimeConfig
  34. from fate_flow.db.service_registry import ServerRegistry
  35. from fate_flow.entity import RetCode
  36. from fate_flow.hook import HookManager
  37. from fate_flow.hook.common.parameters import SignatureParameters
  38. from fate_flow.settings import (
  39. API_VERSION, FATE_FLOW_SERVICE_NAME, HOST, HTTP_PORT,
  40. PARTY_ID, PERMISSION_SWITCH, PROXY, PROXY_PROTOCOL,
  41. REQUEST_MAX_WAIT_SEC, REQUEST_TRY_TIMES, REQUEST_WAIT_SEC,
  42. SITE_AUTHENTICATION, stat_logger,
  43. )
  44. from fate_flow.utils.base_utils import compare_version
  45. from fate_flow.utils.grpc_utils import (
  46. forward_grpc_packet, gen_routing_metadata,
  47. get_command_federation_channel, wrap_grpc_packet,
  48. )
  49. from fate_flow.utils.log_utils import audit_logger, schedule_logger
  50. from fate_flow.utils.permission_utils import get_permission_parameters
  51. from fate_flow.utils.requests_utils import request
  52. fate_version = get_fate_version() or ''
  53. request_headers = {
  54. 'User-Agent': f'{FATE_FLOW_SERVICE_NAME}/{fate_version}',
  55. 'service': FATE_FLOW_SERVICE_NAME,
  56. 'src_fate_ver': fate_version,
  57. }
  58. def get_exponential_backoff_interval(retries, full_jitter=False):
  59. """Calculate the exponential backoff wait time."""
  60. # Will be zero if factor equals 0
  61. countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
  62. # Full jitter according to
  63. # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
  64. if full_jitter:
  65. countdown = random.randrange(countdown + 1)
  66. # Adjust according to maximum wait time and account for negative values.
  67. return max(0, countdown)
  68. def get_json_result(retcode=RetCode.SUCCESS, retmsg='success', data=None, job_id=None, meta=None):
  69. result_dict = {
  70. "retcode": retcode,
  71. "retmsg": retmsg,
  72. "data": data,
  73. "jobId": job_id,
  74. "meta": meta,
  75. }
  76. response = {}
  77. for key, value in result_dict.items():
  78. if value is not None:
  79. response[key] = value
  80. return jsonify(response)
  81. def server_error_response(e):
  82. stat_logger.exception(e)
  83. if len(e.args) > 1:
  84. return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e.args[0]), data=e.args[1])
  85. return get_json_result(retcode=RetCode.EXCEPTION_ERROR, retmsg=repr(e))
  86. def error_response(response_code, retmsg=None):
  87. if retmsg is None:
  88. retmsg = HTTP_STATUS_CODES.get(response_code, 'Unknown Error')
  89. return Response(json.dumps({
  90. 'retmsg': retmsg,
  91. 'retcode': response_code,
  92. }), status=response_code, mimetype='application/json')
  93. def federated_api(job_id, method, endpoint, src_party_id, dest_party_id, src_role, json_body, federated_mode):
  94. src_party_id = str(src_party_id or '')
  95. dest_party_id = str(dest_party_id or '')
  96. src_role = src_role or ''
  97. headers = request_headers.copy()
  98. headers.update({
  99. 'src_party_id': src_party_id,
  100. 'dest_party_id': dest_party_id,
  101. 'src_role': src_role,
  102. })
  103. if SITE_AUTHENTICATION:
  104. sign_obj = HookManager.site_signature(SignatureParameters(PARTY_ID, json_body))
  105. headers['site_signature'] = sign_obj.site_signature or ''
  106. kwargs = {
  107. 'job_id': job_id,
  108. 'method': method,
  109. 'endpoint': endpoint,
  110. 'src_party_id': src_party_id,
  111. 'dest_party_id': dest_party_id,
  112. 'src_role': src_role,
  113. 'json_body': json_body,
  114. 'headers': headers,
  115. }
  116. if federated_mode == FederatedMode.SINGLE or kwargs['dest_party_id'] == '0':
  117. kwargs.update({
  118. 'host': RuntimeConfig.JOB_SERVER_HOST,
  119. 'port': RuntimeConfig.HTTP_PORT,
  120. })
  121. return federated_coordination_on_http(**kwargs)
  122. if federated_mode == FederatedMode.MULTIPLE:
  123. host, port, protocol = get_federated_proxy_address(kwargs['src_party_id'], kwargs['dest_party_id'])
  124. kwargs.update({
  125. 'host': host,
  126. 'port': port,
  127. })
  128. if protocol == CoordinationCommunicationProtocol.HTTP:
  129. return federated_coordination_on_http(**kwargs)
  130. if protocol == CoordinationCommunicationProtocol.GRPC:
  131. return federated_coordination_on_grpc(**kwargs)
  132. raise Exception(f'{protocol} coordination communication protocol is not supported.')
  133. raise Exception(f'{federated_mode} work mode is not supported')
  134. def local_api(job_id, method, endpoint, json_body):
  135. return federated_api(
  136. job_id=job_id, method=method, endpoint=endpoint, json_body=json_body,
  137. src_party_id=PARTY_ID, dest_party_id=PARTY_ID, src_role='',
  138. federated_mode=FederatedMode.SINGLE,
  139. )
  140. def cluster_api(method, host, port, endpoint, json_body, headers=None):
  141. return federated_coordination_on_http(
  142. job_id='', method=method, host=host, port=port, endpoint=endpoint,
  143. json_body=json_body, headers=headers or request_headers.copy(),
  144. )
  145. def get_federated_proxy_address(src_party_id, dest_party_id):
  146. src_party_id = str(src_party_id)
  147. dest_party_id = str(dest_party_id)
  148. if PROXY_PROTOCOL == "default":
  149. protocol = CoordinationCommunicationProtocol.HTTP
  150. else:
  151. protocol = PROXY_PROTOCOL
  152. if isinstance(PROXY, dict):
  153. proxy_name = PROXY.get("name", CoordinationProxyService.FATEFLOW)
  154. if proxy_name == CoordinationProxyService.FATEFLOW and src_party_id == dest_party_id:
  155. host = RuntimeConfig.JOB_SERVER_HOST
  156. port = RuntimeConfig.HTTP_PORT
  157. else:
  158. host = PROXY["host"]
  159. port = PROXY[f"{protocol}_port"]
  160. return (
  161. host,
  162. port,
  163. protocol,
  164. )
  165. if PROXY == CoordinationProxyService.ROLLSITE:
  166. proxy_address = ServerRegistry.FATE_ON_EGGROLL[CoordinationProxyService.ROLLSITE]
  167. return (
  168. proxy_address["host"],
  169. proxy_address.get("grpc_port", proxy_address["port"]),
  170. CoordinationCommunicationProtocol.GRPC,
  171. )
  172. if PROXY == CoordinationProxyService.NGINX:
  173. proxy_address = ServerRegistry.FATE_ON_SPARK[CoordinationProxyService.NGINX]
  174. return (
  175. proxy_address["host"],
  176. proxy_address[f"{protocol}_port"],
  177. protocol,
  178. )
  179. raise RuntimeError(f"can not support coordinate proxy {PROXY}")
  180. def federated_coordination_on_http(
  181. job_id, method, host, port, endpoint,
  182. json_body, headers, **_,
  183. ):
  184. url = f'http://{host}:{port}/{API_VERSION}{endpoint}'
  185. timeout = JobDefaultConfig.remote_request_timeout or 0
  186. timeout = timeout / 1000 or None
  187. for t in range(REQUEST_TRY_TIMES):
  188. try:
  189. response = request(
  190. method=method, url=url, timeout=timeout,
  191. headers=headers, json=json_body,
  192. )
  193. response.raise_for_status()
  194. except Exception as e:
  195. schedule_logger(job_id).warning(f'http api error: {url}\n{e}')
  196. if t >= REQUEST_TRY_TIMES - 1:
  197. raise e
  198. else:
  199. audit_logger(job_id).info(f'http api response: {url}\n{response.text}')
  200. return response.json()
  201. time.sleep(get_exponential_backoff_interval(t))
  202. def federated_coordination_on_grpc(
  203. job_id, method, host, port, endpoint,
  204. src_party_id, dest_party_id,
  205. json_body, headers, **_,
  206. ):
  207. endpoint = f"/{API_VERSION}{endpoint}"
  208. timeout = JobDefaultConfig.remote_request_timeout or 0
  209. _packet = wrap_grpc_packet(
  210. json_body=json_body, http_method=method, url=endpoint,
  211. src_party_id=src_party_id, dst_party_id=dest_party_id,
  212. job_id=job_id, headers=headers, overall_timeout=timeout,
  213. )
  214. _routing_metadata = gen_routing_metadata(
  215. src_party_id=src_party_id, dest_party_id=dest_party_id,
  216. )
  217. for t in range(REQUEST_TRY_TIMES):
  218. channel, stub = get_command_federation_channel(host, port)
  219. try:
  220. _return, _call = stub.unaryCall.with_call(
  221. _packet, metadata=_routing_metadata,
  222. timeout=timeout / 1000 or None,
  223. )
  224. except Exception as e:
  225. schedule_logger(job_id).warning(f'grpc api error: {endpoint}\n{e}')
  226. if t >= REQUEST_TRY_TIMES - 1:
  227. raise e
  228. else:
  229. audit_logger(job_id).info(f'grpc api response: {endpoint}\n{_return}')
  230. return json_loads(_return.body.value)
  231. finally:
  232. channel.close()
  233. time.sleep(get_exponential_backoff_interval(t))
  234. def proxy_api(role, _job_id, request_config):
  235. headers = request_config.get('header', {})
  236. body = request_config.get('body', {})
  237. method = headers.get('METHOD', 'POST')
  238. endpoint = headers.get('ENDPOINT', '')
  239. job_id = headers.get('JOB-ID', _job_id)
  240. src_party_id = headers.get('SRC-PARTY-ID', '')
  241. dest_party_id = headers.get('DEST-PARTY-ID', '')
  242. _packet = forward_grpc_packet(body, method, endpoint, src_party_id, dest_party_id, role, job_id)
  243. _routing_metadata = gen_routing_metadata(src_party_id, dest_party_id)
  244. host, port, protocol = get_federated_proxy_address(src_party_id, dest_party_id)
  245. channel, stub = get_command_federation_channel(host, port)
  246. _return, _call = stub.unaryCall.with_call(_packet, metadata=_routing_metadata)
  247. channel.close()
  248. response = json_loads(_return.body.value)
  249. return response
  250. def forward_api(role, request_config):
  251. role = role.upper()
  252. if not hasattr(ServerRegistry, role):
  253. ServerRegistry.load()
  254. if not hasattr(ServerRegistry, role):
  255. return {'retcode': 404, 'retmsg': f'role "{role.lower()}" not supported'}
  256. registry = getattr(ServerRegistry, role)
  257. headers = request_config.get('header', {})
  258. body = request_config.get('body', {})
  259. method = headers.get('METHOD', 'POST')
  260. endpoint = headers.get('ENDPOINT', '')
  261. ip = registry.get('host', '')
  262. port = registry.get('port', '')
  263. url = f'http://{ip}:{port}{endpoint}'
  264. audit_logger().info(f'api request: {url}')
  265. response = request(method=method, url=url, json=body, headers=headers)
  266. response = (
  267. response.json() if response.status_code == 200
  268. else {'retcode': response.status_code, 'retmsg': response.text}
  269. )
  270. audit_logger().info(response)
  271. return response
  272. def create_job_request_check(func):
  273. @wraps(func)
  274. def _wrapper(*_args, **_kwargs):
  275. party_id = _kwargs.get("party_id")
  276. role = _kwargs.get("role")
  277. body = flask_request.json
  278. headers = flask_request.headers
  279. src_role = headers.get("scr_role")
  280. src_party_id = headers.get("src_party_id")
  281. # permission check
  282. if PERMISSION_SWITCH:
  283. permission_return = HookManager.permission_check(get_permission_parameters(role, party_id, src_role,
  284. src_party_id, body))
  285. if permission_return.code != RetCode.SUCCESS:
  286. return get_json_result(
  287. retcode=RetCode.PERMISSION_ERROR,
  288. retmsg='permission check failed',
  289. data=permission_return.to_dict()
  290. )
  291. # version check
  292. src_fate_ver = headers.get('src_fate_ver')
  293. if src_fate_ver is not None and compare_version(src_fate_ver, '1.7.0') == 'lt':
  294. return get_json_result(retcode=RetCode.INCOMPATIBLE_FATE_VER, retmsg='Incompatible FATE versions',
  295. data={'src_fate_ver': src_fate_ver,
  296. "current_fate_ver": RuntimeConfig.get_env('FATE')})
  297. return func(*_args, **_kwargs)
  298. return _wrapper
  299. def validate_request(*args, **kwargs):
  300. def wrapper(func):
  301. @wraps(func)
  302. def decorated_function(*_args, **_kwargs):
  303. input_arguments = flask_request.json or flask_request.form.to_dict()
  304. no_arguments = []
  305. error_arguments = []
  306. for arg in args:
  307. if arg not in input_arguments:
  308. no_arguments.append(arg)
  309. for k, v in kwargs.items():
  310. config_value = input_arguments.get(k, None)
  311. if config_value is None:
  312. no_arguments.append(k)
  313. elif isinstance(v, (tuple, list)):
  314. if config_value not in v:
  315. error_arguments.append((k, set(v)))
  316. elif config_value != v:
  317. error_arguments.append((k, v))
  318. if no_arguments or error_arguments:
  319. error_string = ""
  320. if no_arguments:
  321. error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
  322. if error_arguments:
  323. error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
  324. return get_json_result(retcode=RetCode.ARGUMENT_ERROR, retmsg=error_string)
  325. return func(*_args, **_kwargs)
  326. return decorated_function
  327. return wrapper
  328. def cluster_route(func):
  329. @wraps(func)
  330. def _route(*args, **kwargs):
  331. request_data = flask_request.json or flask_request.form.to_dict()
  332. instance_id = request_data.get('instance_id')
  333. if not instance_id:
  334. return func(*args, **kwargs)
  335. request_data['forward_times'] = int(request_data.get('forward_times', 0)) + 1
  336. if request_data['forward_times'] > 2:
  337. return error_response(429, 'Too many forwarding times.')
  338. instance = RuntimeConfig.SERVICE_DB.get_servers().get(instance_id)
  339. if instance is None:
  340. return error_response(404, 'Flow Instance not found.')
  341. if instance.http_address == f'{HOST}:{HTTP_PORT}':
  342. return func(*args, **kwargs)
  343. endpoint = flask_request.full_path
  344. prefix = f'/{API_VERSION}/'
  345. if endpoint.startswith(prefix):
  346. endpoint = endpoint[len(prefix) - 1:]
  347. response = cluster_api(
  348. method=flask_request.method,
  349. host=instance.host,
  350. port=instance.http_port,
  351. endpoint=endpoint,
  352. json_body=request_data,
  353. headers=flask_request.headers,
  354. )
  355. return get_json_result(**response)
  356. return _route
  357. def is_localhost(ip):
  358. return ip in {'127.0.0.1', '::1', '[::1]', 'localhost'}
  359. def send_file_in_mem(data, filename):
  360. if not isinstance(data, (str, bytes)):
  361. data = json_dumps(data)
  362. if isinstance(data, str):
  363. data = data.encode('utf-8')
  364. f = BytesIO()
  365. f.write(data)
  366. f.seek(0)
  367. return send_file(f, as_attachment=True, attachment_filename=filename)