123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import grpc
- from fate_arch.protobuf.python import basic_meta_pb2, proxy_pb2, proxy_pb2_grpc
- from fate_arch.common.base_utils import json_dumps, json_loads
- from fate_flow.db.job_default_config import JobDefaultConfig
- from fate_flow.db.runtime_config import RuntimeConfig
- from fate_flow.settings import FATE_FLOW_SERVICE_NAME, GRPC_OPTIONS, GRPC_PORT, HOST
- from fate_flow.utils.log_utils import audit_logger
- from fate_flow.utils.requests_utils import request
- def get_command_federation_channel(host, port):
- channel = grpc.insecure_channel(f"{host}:{port}", GRPC_OPTIONS)
- stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
- return channel, stub
- def gen_routing_metadata(src_party_id, dest_party_id):
- routing_head = (
- ("service", "fateflow"),
- ("src-party-id", str(src_party_id)),
- ("src-role", "guest"),
- ("dest-party-id", str(dest_party_id)),
- ("dest-role", "host"),
- )
- return routing_head
- def wrap_grpc_packet(json_body, http_method, url, src_party_id, dst_party_id, job_id=None, headers=None, overall_timeout=None):
- overall_timeout = JobDefaultConfig.remote_request_timeout if overall_timeout is None else overall_timeout
- _src_end_point = basic_meta_pb2.Endpoint(ip=HOST, port=GRPC_PORT)
- _src = proxy_pb2.Topic(name=job_id, partyId="{}".format(src_party_id), role=FATE_FLOW_SERVICE_NAME, callback=_src_end_point)
- _dst = proxy_pb2.Topic(name=job_id, partyId="{}".format(dst_party_id), role=FATE_FLOW_SERVICE_NAME, callback=None)
- _model = proxy_pb2.Model(name="headers", dataKey=json_dumps(headers))
- _task = proxy_pb2.Task(taskId=job_id, model=_model)
- _command = proxy_pb2.Command(name=url)
- _conf = proxy_pb2.Conf(overallTimeout=overall_timeout)
- _meta = proxy_pb2.Metadata(src=_src, dst=_dst, task=_task, command=_command, operator=http_method, conf=_conf)
- _data = proxy_pb2.Data(key=url, value=bytes(json_dumps(json_body), 'utf-8'))
- return proxy_pb2.Packet(header=_meta, body=_data)
- def get_url(_suffix):
- return "http://{}:{}/{}".format(RuntimeConfig.JOB_SERVER_HOST, RuntimeConfig.HTTP_PORT, _suffix.lstrip('/'))
- class UnaryService(proxy_pb2_grpc.DataTransferServiceServicer):
- @staticmethod
- def unaryCall(_request, context):
- packet = _request
- header = packet.header
- _suffix = packet.body.key
- param_bytes = packet.body.value
- param = bytes.decode(param_bytes)
- job_id = header.task.taskId
- src = header.src
- dst = header.dst
- headers_str = header.task.model.dataKey if header.task.model.dataKey else "{}"
- headers = json_loads(headers_str)
- method = header.operator
- param_dict = json_loads(param)
- source_routing_header = []
- for key, value in context.invocation_metadata():
- source_routing_header.append((key, value))
- _routing_metadata = gen_routing_metadata(src_party_id=src.partyId, dest_party_id=dst.partyId)
- context.set_trailing_metadata(trailing_metadata=_routing_metadata)
- audit_logger(job_id).info("rpc receive headers: {}".format(headers))
- audit_logger(job_id).info('rpc receive: {}'.format(packet))
- audit_logger(job_id).info("rpc receive: {} {}".format(get_url(_suffix), param))
- resp = request(method=method, url=get_url(_suffix), json=param_dict, headers=headers)
- resp_json = resp.json()
- return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId, src.partyId, job_id)
- def forward_grpc_packet(_json_body, _method, _url, _src_party_id, _dst_party_id, role, job_id=None,
- overall_timeout=None):
- overall_timeout = JobDefaultConfig.remote_request_timeout if overall_timeout is None else overall_timeout
- _src_end_point = basic_meta_pb2.Endpoint(ip=HOST, port=GRPC_PORT)
- _src = proxy_pb2.Topic(name=job_id, partyId="{}".format(_src_party_id), role=FATE_FLOW_SERVICE_NAME, callback=_src_end_point)
- _dst = proxy_pb2.Topic(name=job_id, partyId="{}".format(_dst_party_id), role=role, callback=None)
- _task = proxy_pb2.Task(taskId=job_id)
- _command = proxy_pb2.Command(name=_url)
- _conf = proxy_pb2.Conf(overallTimeout=overall_timeout)
- _meta = proxy_pb2.Metadata(src=_src, dst=_dst, task=_task, command=_command, operator=_method, conf=_conf)
- _data = proxy_pb2.Data(key=_url, value=bytes(json_dumps(_json_body), 'utf-8'))
- return proxy_pb2.Packet(header=_meta, body=_data)
|