grpc_utils.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import grpc
  17. from fate_arch.protobuf.python import basic_meta_pb2, proxy_pb2, proxy_pb2_grpc
  18. from fate_arch.common.base_utils import json_dumps, json_loads
  19. from fate_flow.db.job_default_config import JobDefaultConfig
  20. from fate_flow.db.runtime_config import RuntimeConfig
  21. from fate_flow.settings import FATE_FLOW_SERVICE_NAME, GRPC_OPTIONS, GRPC_PORT, HOST
  22. from fate_flow.utils.log_utils import audit_logger
  23. from fate_flow.utils.requests_utils import request
  24. def get_command_federation_channel(host, port):
  25. channel = grpc.insecure_channel(f"{host}:{port}", GRPC_OPTIONS)
  26. stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
  27. return channel, stub
  28. def gen_routing_metadata(src_party_id, dest_party_id):
  29. routing_head = (
  30. ("service", "fateflow"),
  31. ("src-party-id", str(src_party_id)),
  32. ("src-role", "guest"),
  33. ("dest-party-id", str(dest_party_id)),
  34. ("dest-role", "host"),
  35. )
  36. return routing_head
  37. def wrap_grpc_packet(json_body, http_method, url, src_party_id, dst_party_id, job_id=None, headers=None, overall_timeout=None):
  38. overall_timeout = JobDefaultConfig.remote_request_timeout if overall_timeout is None else overall_timeout
  39. _src_end_point = basic_meta_pb2.Endpoint(ip=HOST, port=GRPC_PORT)
  40. _src = proxy_pb2.Topic(name=job_id, partyId="{}".format(src_party_id), role=FATE_FLOW_SERVICE_NAME, callback=_src_end_point)
  41. _dst = proxy_pb2.Topic(name=job_id, partyId="{}".format(dst_party_id), role=FATE_FLOW_SERVICE_NAME, callback=None)
  42. _model = proxy_pb2.Model(name="headers", dataKey=json_dumps(headers))
  43. _task = proxy_pb2.Task(taskId=job_id, model=_model)
  44. _command = proxy_pb2.Command(name=url)
  45. _conf = proxy_pb2.Conf(overallTimeout=overall_timeout)
  46. _meta = proxy_pb2.Metadata(src=_src, dst=_dst, task=_task, command=_command, operator=http_method, conf=_conf)
  47. _data = proxy_pb2.Data(key=url, value=bytes(json_dumps(json_body), 'utf-8'))
  48. return proxy_pb2.Packet(header=_meta, body=_data)
  49. def get_url(_suffix):
  50. return "http://{}:{}/{}".format(RuntimeConfig.JOB_SERVER_HOST, RuntimeConfig.HTTP_PORT, _suffix.lstrip('/'))
  51. class UnaryService(proxy_pb2_grpc.DataTransferServiceServicer):
  52. @staticmethod
  53. def unaryCall(_request, context):
  54. packet = _request
  55. header = packet.header
  56. _suffix = packet.body.key
  57. param_bytes = packet.body.value
  58. param = bytes.decode(param_bytes)
  59. job_id = header.task.taskId
  60. src = header.src
  61. dst = header.dst
  62. headers_str = header.task.model.dataKey if header.task.model.dataKey else "{}"
  63. headers = json_loads(headers_str)
  64. method = header.operator
  65. param_dict = json_loads(param)
  66. source_routing_header = []
  67. for key, value in context.invocation_metadata():
  68. source_routing_header.append((key, value))
  69. _routing_metadata = gen_routing_metadata(src_party_id=src.partyId, dest_party_id=dst.partyId)
  70. context.set_trailing_metadata(trailing_metadata=_routing_metadata)
  71. audit_logger(job_id).info("rpc receive headers: {}".format(headers))
  72. audit_logger(job_id).info('rpc receive: {}'.format(packet))
  73. audit_logger(job_id).info("rpc receive: {} {}".format(get_url(_suffix), param))
  74. resp = request(method=method, url=get_url(_suffix), json=param_dict, headers=headers)
  75. resp_json = resp.json()
  76. return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId, src.partyId, job_id)
  77. def forward_grpc_packet(_json_body, _method, _url, _src_party_id, _dst_party_id, role, job_id=None,
  78. overall_timeout=None):
  79. overall_timeout = JobDefaultConfig.remote_request_timeout if overall_timeout is None else overall_timeout
  80. _src_end_point = basic_meta_pb2.Endpoint(ip=HOST, port=GRPC_PORT)
  81. _src = proxy_pb2.Topic(name=job_id, partyId="{}".format(_src_party_id), role=FATE_FLOW_SERVICE_NAME, callback=_src_end_point)
  82. _dst = proxy_pb2.Topic(name=job_id, partyId="{}".format(_dst_party_id), role=role, callback=None)
  83. _task = proxy_pb2.Task(taskId=job_id)
  84. _command = proxy_pb2.Command(name=_url)
  85. _conf = proxy_pb2.Conf(overallTimeout=overall_timeout)
  86. _meta = proxy_pb2.Metadata(src=_src, dst=_dst, task=_task, command=_command, operator=_method, conf=_conf)
  87. _data = proxy_pb2.Data(key=_url, value=bytes(json_dumps(_json_body), 'utf-8'))
  88. return proxy_pb2.Packet(header=_meta, body=_data)