api_reader.py 8.7 KB


  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import os
  18. import time
  19. from contextlib import closing
  20. import requests
  21. from requests_toolbelt import MultipartEncoder
  22. from fate_arch.common.data_utils import default_output_info
  23. from fate_arch.session import Session
  24. from fate_flow.components._base import ComponentMeta, BaseParam, ComponentBase, ComponentInputProtocol
  25. from fate_flow.db.service_registry import ServiceRegistry
  26. from fate_flow.entity import Metric
  27. from fate_flow.settings import TEMP_DIRECTORY
  28. from fate_flow.utils.data_utils import convert_output
  29. from fate_flow.utils.log_utils import getLogger
  30. from fate_flow.utils.upload_utils import UploadFile
# Module-level logger used by every method of this component.
logger = getLogger()
# Component metadata for "ApiReader"; ApiReaderParam and the ApiReader runner
# below are registered on it through its bind_param / bind_runner decorators.
api_reader_cpn_meta = ComponentMeta("ApiReader")
  33. @api_reader_cpn_meta.bind_param
  34. class ApiReaderParam(BaseParam):
  35. def __init__(
  36. self,
  37. server_name=None,
  38. parameters=None,
  39. id_delimiter=",",
  40. head=True,
  41. extend_sid=False,
  42. timeout=60 * 12
  43. ):
  44. self.server_name = server_name
  45. self.parameters = parameters
  46. self.id_delimiter = id_delimiter
  47. self.head = head
  48. self.extend_sid = extend_sid
  49. self.timeout = timeout
  50. def check(self):
  51. return True
  52. @api_reader_cpn_meta.bind_runner.on_guest.on_host
  53. class ApiReader(ComponentBase):
  54. def __init__(self):
  55. super(ApiReader, self).__init__()
  56. self.parameters = {}
  57. self.required_url_key_list = ["upload", "query", "download"]
  58. self.service_info = {}
  59. def _run(self, cpn_input: ComponentInputProtocol):
  60. self.cpn_input = cpn_input
  61. self.parameters = cpn_input.parameters
  62. self.task_dir = os.path.join(TEMP_DIRECTORY, self.tracker.task_id, str(self.tracker.task_version))
  63. for cpn_name, data in cpn_input.datasets.items():
  64. for data_name, table_list in data.items():
  65. self.input_table = table_list[0]
  66. logger.info(f"parameters: {self.parameters}")
  67. if not self.parameters.get("server_name"):
  68. self._run_guest()
  69. else:
  70. self._run_host()
  71. def _run_guest(self):
  72. self.data_output = [self.input_table]
  73. def _run_host(self):
  74. self.set_service_registry_info()
  75. response = self.upload_data()
  76. logger.info(f"upload response: {response.text}")
  77. if response.status_code == 200:
  78. response_data = response.json()
  79. if response_data.get("code") == 0:
  80. logger.info(f"request success, start check status")
  81. job_id = response_data.get("data").get("jobId")
  82. status = self.check_status(job_id)
  83. if status:
  84. download_path = self.download_data(job_id)
  85. table, output_name, output_namespace = self.output_feature_table()
  86. count = UploadFile.upload(
  87. download_path,
  88. head=self.parameters.get("head"),
  89. table=table,
  90. id_delimiter=self.parameters.get("id_delimiter"),
  91. extend_sid=self.parameters.get("extend_sid")
  92. )
  93. table.meta.update_metas(count=count)
  94. self.tracker.log_output_data_info(
  95. data_name=self.cpn_input.flow_feeded_parameters.get("output_data_name")[0],
  96. table_namespace=output_namespace,
  97. table_name=output_name,
  98. )
  99. self.tracker.log_metric_data(
  100. metric_namespace="api_reader",
  101. metric_name="upload",
  102. metrics=[Metric("count", count)],
  103. )
  104. else:
  105. raise Exception(f"upload return: {response.text}")
  106. def output_feature_table(self):
  107. (
  108. output_name,
  109. output_namespace
  110. ) = default_output_info(
  111. task_id=self.tracker.task_id,
  112. task_version=self.tracker.task_version,
  113. output_type="data"
  114. )
  115. logger.info(f"flow_feeded_parameters: {self.cpn_input.flow_feeded_parameters}")
  116. input_table_info = self.cpn_input.flow_feeded_parameters.get("table_info")[0]
  117. _, output_table_address, output_table_engine = convert_output(
  118. input_table_info["name"],
  119. input_table_info["namespace"],
  120. output_name,
  121. output_namespace, self.input_table.engine
  122. )
  123. sess = Session.get_global()
  124. output_table_session = sess.storage(storage_engine=output_table_engine)
  125. table = output_table_session.create_table(
  126. address=output_table_address,
  127. name=output_name,
  128. namespace=output_namespace,
  129. partitions=self.input_table.partitions,
  130. )
  131. return table, output_name, output_namespace
  132. def check_status(self, job_id):
  133. query_registry_info = self.service_info.get("query")
  134. logger.info(f"parameters timeout: {self.parameters.get('timeout', 60 * 12)} min")
  135. for i in range(0, self.parameters.get("timeout", 60 * 12)):
  136. status_response = getattr(requests, query_registry_info.f_method.lower(), None)(
  137. url=query_registry_info.f_url,
  138. json={"jobId": job_id}
  139. )
  140. logger.info(f"status: {status_response.text}")
  141. if status_response.status_code == 200:
  142. if status_response.json().get("data").get("status").lower() == "success":
  143. logger.info(f"job id {job_id} status success, start download")
  144. return True
  145. if status_response.json().get("data").get("status").lower() != "running":
  146. logger.error(f"job id {job_id} status: {status_response.json().get('data').get('status')}")
  147. raise Exception(status_response.json().get("data"))
  148. logger.info(f"job id {job_id} status: {status_response.json().get('data').get('status')}")
  149. time.sleep(60)
  150. raise TimeoutError("check status timeout")
  151. def download_data(self, job_id):
  152. download_registry_info = self.service_info.get("download")
  153. download_path = os.path.join(self.task_dir, "features")
  154. logger.info(f"start download feature, url: {download_registry_info.f_url}")
  155. params = {"jobId": job_id}
  156. with closing(getattr(requests, download_registry_info.f_method.lower(), None)(
  157. url=download_registry_info.f_url,
  158. params={"requestBody": json.dumps(params)},
  159. stream=True)) as response:
  160. if response.status_code == 200:
  161. with open(download_path, 'wb') as fw:
  162. for chunk in response.iter_content(1024):
  163. if chunk:
  164. fw.write(chunk)
  165. else:
  166. raise Exception(f"download return: {response.text}")
  167. return download_path
  168. def upload_data(self):
  169. id_path = os.path.join(self.task_dir, "id")
  170. logger.info(f"save to: {id_path}")
  171. os.makedirs(os.path.dirname(id_path), exist_ok=True)
  172. with open(id_path, "w") as f:
  173. for k, _ in self.input_table.collect():
  174. f.write(f"{k}\n")
  175. data = MultipartEncoder(
  176. fields={'file': (id_path, f, 'application/octet-stream')}
  177. )
  178. upload_registry_info = self.service_info.get("upload")
  179. logger.info(f"upload info:{upload_registry_info.to_dict()}")
  180. response = getattr(requests, upload_registry_info.f_method.lower(), None)(
  181. url=upload_registry_info.f_url,
  182. params={"requestBody": json.dumps(self.parameters.get("parameters", {}))},
  183. data=data,
  184. headers={'Content-Type': data.content_type}
  185. )
  186. return response
  187. def set_service_registry_info(self):
  188. for info in ServiceRegistry().load_service(server_name=self.parameters.get("server_name")):
  189. for key in self.required_url_key_list:
  190. if key == info.f_service_name:
  191. self.service_info[key] = info
  192. logger.info(f"set service registry info:{self.service_info}")