task_utils.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import functools
  2. from flask import request as flask_request
  3. from fate_flow.db.runtime_config import RuntimeConfig
  4. from fate_flow.entity import RetCode
  5. from fate_flow.operation.job_saver import JobSaver
  6. from fate_flow.utils.api_utils import get_json_result
  7. from fate_flow.utils.requests_utils import request
  8. def task_request_proxy(filter_local=False, force=True):
  9. def _outer(func):
  10. @functools.wraps(func)
  11. def _wrapper(*args, **kwargs):
  12. party_id, role, task_id, task_version = kwargs.get("party_id"), kwargs.get("role"), \
  13. kwargs.get("task_id"), kwargs.get("task_version")
  14. if not filter_local or (filter_local and role == "local"):
  15. tasks = JobSaver.query_task(task_id=task_id, task_version=task_version, role=role, party_id=party_id)
  16. if tasks:
  17. if tasks[0].f_run_ip and tasks[0].f_run_port:
  18. if tasks[0].f_run_ip != RuntimeConfig.JOB_SERVER_HOST:
  19. source_url = flask_request.url
  20. source_address = source_url.split("/")[2]
  21. dest_address = ":".join([tasks[0].f_run_ip, str(tasks[0].f_run_port)])
  22. dest_url = source_url.replace(source_address, dest_address)
  23. try:
  24. response = request(method=flask_request.method, url=dest_url, json=flask_request.json, headers=flask_request.headers)
  25. if 200 <= response.status_code < 300:
  26. response = response.json()
  27. return get_json_result(retcode=response.get("retcode"),
  28. retmsg=response.get('retmsg'))
  29. else:
  30. raise Exception(f"status_code: {response.status_code}, text: {response.text}")
  31. except Exception as e:
  32. if force:
  33. return func(*args, **kwargs)
  34. raise e
  35. else:
  36. return get_json_result(retcode=RetCode.DATA_ERROR, retmsg='no found task')
  37. return func(*args, **kwargs)
  38. return _wrapper
  39. return _outer