job.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import os
  17. import json
  18. from contextlib import closing
  19. from flow_sdk.client.api.base import BaseFlowAPI
  20. from flow_sdk.utils import preprocess, check_config, download_from_request
  21. class Job(BaseFlowAPI):
  22. def list(self, limit=10):
  23. kwargs = locals()
  24. config_data, dsl_data = preprocess(**kwargs)
  25. return self._post(url='job/list/job', json=config_data)
  26. def view(self, job_id=None, role=None, party_id=None, status=None):
  27. kwargs = locals()
  28. config_data, dsl_data = preprocess(**kwargs)
  29. return self._post(url='job/data/view/query', json=config_data)
  30. def submit(self, config_data, dsl_data=None):
  31. kwargs = locals()
  32. config_data, dsl_data = preprocess(**kwargs)
  33. return self._post(url='job/submit', json={
  34. 'job_runtime_conf': config_data,
  35. 'job_dsl': dsl_data,
  36. })
  37. def stop(self, job_id):
  38. job_id = str(job_id)
  39. kwargs = locals()
  40. config_data, dsl_data = preprocess(**kwargs)
  41. check_config(config=config_data, required_arguments=['job_id'])
  42. return self._post(url='job/stop', json=config_data)
  43. def query(self, job_id=None, role=None, party_id=None, component_name=None, status=None):
  44. kwargs = locals()
  45. config_data, dsl_data = preprocess(**kwargs)
  46. return self._post(url='job/query', json=config_data)
  47. def config(self, job_id, role, party_id, output_path):
  48. kwargs = locals()
  49. config_data, dsl_data = preprocess(**kwargs)
  50. check_config(config=config_data, required_arguments=['job_id', 'role', 'party_id', 'output_path'])
  51. response = self._post(url='job/config', json=config_data)
  52. if response['retcode'] == 0:
  53. job_id = response['data']['job_id']
  54. download_directory = os.path.join(config_data['output_path'], 'job_{}_config'.format(job_id))
  55. os.makedirs(download_directory, exist_ok=True)
  56. for k, v in response['data'].items():
  57. if k == 'job_id':
  58. continue
  59. with open('{}/{}.json'.format(download_directory, k), 'w') as fw:
  60. json.dump(v, fw, indent=4)
  61. del response['data']['dsl']
  62. del response['data']['runtime_conf']
  63. response['directory'] = download_directory
  64. response['retmsg'] = 'download successfully, please check {} directory'.format(download_directory)
  65. return response
  66. def log(self, job_id, output_path):
  67. kwargs = locals()
  68. config_data, dsl_data = preprocess(**kwargs)
  69. check_config(config=config_data, required_arguments=['job_id', 'output_path'])
  70. job_id = config_data['job_id']
  71. tar_file_name = 'job_{}_log.tar.gz'.format(job_id)
  72. extract_dir = os.path.join(config_data['output_path'], 'job_{}_log'.format(job_id))
  73. with closing(self._post(url='job/log/download', handle_result=False, json=config_data, stream=True)) as response:
  74. if response.status_code == 200:
  75. download_from_request(http_response=response, tar_file_name=tar_file_name, extract_dir=extract_dir)
  76. response = {'retcode': 0,
  77. 'directory': extract_dir,
  78. 'retmsg': 'download successfully, please check {} directory'.format(extract_dir)}
  79. else:
  80. response = response.json()
  81. return response
  82. def generate_dsl(self, train_dsl, cpn):
  83. """
  84. @param train_dsl: dict or str
  85. @param cpn: list or str
  86. """
  87. if isinstance(train_dsl, dict):
  88. train_dsl = json.dumps(train_dsl)
  89. config_data = {
  90. "cpn_str": cpn,
  91. "train_dsl": train_dsl,
  92. "version": "2"
  93. }
  94. res = self._post(url="job/dsl/generate", handle_result=True, json=config_data)
  95. if not res.get("data"):
  96. res["data"] = {}
  97. return res
  98. # TODO complete it in next version
  99. # def clean(self, job_id=None, role=None, party_id=None, component_name=None):
  100. # kwargs = locals()
  101. # config_data, dsl_data = preprocess(**kwargs)
  102. # check_config(config=config_data, required_arguments=['job_id'])
  103. # return self._post(url='job/clean', json=config_data)