#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import time
from contextlib import closing

import requests
from requests_toolbelt import MultipartEncoder

from fate_arch.common.data_utils import default_output_info
from fate_arch.session import Session
from fate_flow.components._base import (
    BaseParam,
    ComponentBase,
    ComponentInputProtocol,
    ComponentMeta,
)
from fate_flow.db.service_registry import ServiceRegistry
from fate_flow.entity import Metric
from fate_flow.settings import TEMP_DIRECTORY
from fate_flow.utils.data_utils import convert_output
from fate_flow.utils.log_utils import getLogger
from fate_flow.utils.upload_utils import UploadFile

logger = getLogger()

api_reader_cpn_meta = ComponentMeta("ApiReader")
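
# The ApiReader component lets a host pull features for its sample ids from an
# external HTTP service: it uploads the ids of its input table, polls the
# remote job status, downloads the resulting feature file, and loads it into a
# new FATE table. The guest side has no external service to call and simply
# passes its input table through.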


@api_reader_cpn_meta.bind_param
class ApiReaderParam(BaseParam):
    def __init__(
        self,
        server_name=None,
        parameters=None,
        id_delimiter=",",
        head=True,
        extend_sid=False,
        timeout=60 * 12,
    ):
        self.server_name = server_name
        self.parameters = parameters
        self.id_delimiter = id_delimiter
        self.head = head
        self.extend_sid = extend_sid
        self.timeout = timeout

    def check(self):
        return True
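
# A sketch of how these parameters might appear on the host side of a job
# conf. The values are illustrative only: "server_name" must match a service
# registry entry, and the nested "parameters" dict is forwarded verbatim to
# the external service as its requestBody.
#
#   "api_reader_0": {
#       "server_name": "feature_service",  # hypothetical registry name
#       "parameters": {"version": "v1"},   # opaque payload for the service
#       "id_delimiter": ",",
#       "head": true,
#       "extend_sid": false,
#       "timeout": 720                     # minutes to poll before giving up
#   }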


@api_reader_cpn_meta.bind_runner.on_guest.on_host
class ApiReader(ComponentBase):
    def __init__(self):
        super(ApiReader, self).__init__()
        self.parameters = {}
        self.required_url_key_list = ["upload", "query", "download"]
        self.service_info = {}

    def _run(self, cpn_input: ComponentInputProtocol):
        self.cpn_input = cpn_input
        self.parameters = cpn_input.parameters
        self.task_dir = os.path.join(
            TEMP_DIRECTORY, self.tracker.task_id, str(self.tracker.task_version)
        )
        # Keep the first table of each input dataset; if several datasets are
        # fed in, the last one iterated over wins.
        for cpn_name, data in cpn_input.datasets.items():
            for data_name, table_list in data.items():
                self.input_table = table_list[0]
        logger.info(f"parameters: {self.parameters}")
        # Only the host is configured with a server_name; without one the
        # component runs the guest-side passthrough.
        if not self.parameters.get("server_name"):
            self._run_guest()
        else:
            self._run_host()

    def _run_guest(self):
        self.data_output = [self.input_table]

    def _run_host(self):
        self.set_service_registry_info()
        response = self.upload_data()
        logger.info(f"upload response: {response.text}")
        if response.status_code == 200:
            response_data = response.json()
            if response_data.get("code") == 0:
                logger.info("request success, start to check status")
                job_id = response_data.get("data").get("jobId")
                status = self.check_status(job_id)
                if status:
                    download_path = self.download_data(job_id)
                    table, output_name, output_namespace = self.output_feature_table()
                    count = UploadFile.upload(
                        download_path,
                        head=self.parameters.get("head"),
                        table=table,
                        id_delimiter=self.parameters.get("id_delimiter"),
                        extend_sid=self.parameters.get("extend_sid"),
                    )
                    table.meta.update_metas(count=count)
                    self.tracker.log_output_data_info(
                        data_name=self.cpn_input.flow_feeded_parameters.get("output_data_name")[0],
                        table_namespace=output_namespace,
                        table_name=output_name,
                    )
                    self.tracker.log_metric_data(
                        metric_namespace="api_reader",
                        metric_name="upload",
                        metrics=[Metric("count", count)],
                    )
            else:
                raise Exception(f"upload return: {response.text}")
        else:
            # Fail loudly on transport-level errors as well, instead of
            # returning without producing any output data.
            raise Exception(f"upload status code {response.status_code}, return: {response.text}")
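
    # The host-side flow above assumes the external service answers the upload
    # call with JSON shaped roughly like {"code": 0, "data": {"jobId": "..."}};
    # any other shape or status code is treated as a failure.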

    def output_feature_table(self):
        output_name, output_namespace = default_output_info(
            task_id=self.tracker.task_id,
            task_version=self.tracker.task_version,
            output_type="data",
        )
        logger.info(f"flow_feeded_parameters: {self.cpn_input.flow_feeded_parameters}")
        input_table_info = self.cpn_input.flow_feeded_parameters.get("table_info")[0]
        _, output_table_address, output_table_engine = convert_output(
            input_table_info["name"],
            input_table_info["namespace"],
            output_name,
            output_namespace,
            self.input_table.engine,
        )
        sess = Session.get_global()
        output_table_session = sess.storage(storage_engine=output_table_engine)
        table = output_table_session.create_table(
            address=output_table_address,
            name=output_name,
            namespace=output_namespace,
            partitions=self.input_table.partitions,
        )
        return table, output_name, output_namespace
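
    # The output table is derived from the input table's storage engine and
    # reuses its partition count, so downstream components see a drop-in
    # replacement for the original input.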

    def check_status(self, job_id):
        query_registry_info = self.service_info.get("query")
        # timeout is interpreted in minutes: poll once a minute for at most
        # `timeout` rounds.
        timeout = self.parameters.get("timeout", 60 * 12)
        logger.info(f"check status timeout: {timeout} min")
        for _ in range(timeout):
            status_response = getattr(requests, query_registry_info.f_method.lower(), None)(
                url=query_registry_info.f_url,
                json={"jobId": job_id},
            )
            logger.info(f"status: {status_response.text}")
            if status_response.status_code == 200:
                status = status_response.json().get("data").get("status").lower()
                if status == "success":
                    logger.info(f"job id {job_id} status success, start download")
                    return True
                if status != "running":
                    logger.error(f"job id {job_id} status: {status}")
                    raise Exception(status_response.json().get("data"))
                logger.info(f"job id {job_id} status: {status}")
            time.sleep(60)
        raise TimeoutError("check status timeout")
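
    # check_status expects the query endpoint to reply with JSON of the form
    # {"data": {"status": "..."}}: "success" ends polling, "running" keeps
    # waiting, and anything else aborts the task.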

    def download_data(self, job_id):
        download_registry_info = self.service_info.get("download")
        download_path = os.path.join(self.task_dir, "features")
        logger.info(f"start to download features, url: {download_registry_info.f_url}")
        params = {"jobId": job_id}
        with closing(getattr(requests, download_registry_info.f_method.lower(), None)(
            url=download_registry_info.f_url,
            params={"requestBody": json.dumps(params)},
            stream=True,
        )) as response:
            if response.status_code == 200:
                # Stream the body to disk in 1 KB chunks so the whole feature
                # file never has to sit in memory.
                with open(download_path, "wb") as fw:
                    for chunk in response.iter_content(1024):
                        if chunk:
                            fw.write(chunk)
            else:
                raise Exception(f"download return: {response.text}")
        return download_path

    def upload_data(self):
        id_path = os.path.join(self.task_dir, "id")
        logger.info(f"save to: {id_path}")
        os.makedirs(os.path.dirname(id_path), exist_ok=True)
        with open(id_path, "w") as f:
            for k, _ in self.input_table.collect():
                f.write(f"{k}\n")
        upload_registry_info = self.service_info.get("upload")
        logger.info(f"upload info: {upload_registry_info.to_dict()}")
        # Re-open the id file in binary mode for the multipart upload; the
        # write handle above was closed when its with block exited, so it can
        # no longer be passed to MultipartEncoder.
        with open(id_path, "rb") as fr:
            data = MultipartEncoder(
                fields={"file": (id_path, fr, "application/octet-stream")}
            )
            response = getattr(requests, upload_registry_info.f_method.lower(), None)(
                url=upload_registry_info.f_url,
                params={"requestBody": json.dumps(self.parameters.get("parameters", {}))},
                data=data,
                headers={"Content-Type": data.content_type},
            )
        return response
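
    # MultipartEncoder streams the file from disk rather than buffering the
    # whole request body, so memory use stays flat even for large id lists.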

    def set_service_registry_info(self):
        # Keep only the upload/query/download endpoints registered under the
        # configured server_name.
        for info in ServiceRegistry().load_service(server_name=self.parameters.get("server_name")):
            for key in self.required_url_key_list:
                if key == info.f_service_name:
                    self.service_info[key] = info
        logger.info(f"set service registry info: {self.service_info}")
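
# For the host flow to work, three endpoints must be registered under the
# chosen server_name, one per entry in required_url_key_list. The methods and
# paths below are only an illustration; each registry record supplies its own
# f_method and f_url:
#
#   upload   -> e.g. POST <service>/upload    (multipart file of sample ids)
#   query    -> e.g. POST <service>/query     (json body: {"jobId": ...})
#   download -> e.g. GET  <service>/download  (streams the feature file back)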