server.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import time
  17. import sys
  18. import grpc
  19. import requests
  20. from grpc._cython import cygrpc
  21. from fate_arch.protobuf.python import basic_meta_pb2, proxy_pb2, proxy_pb2_grpc
  22. from fate_arch.common.base_utils import json_dumps, json_loads
  23. from fate_flow.db.runtime_config import RuntimeConfig
  24. from fate_flow.db.job_default_config import JobDefaultConfig
  25. from fate_flow.settings import FATE_FLOW_SERVICE_NAME, stat_logger, HOST, GRPC_PORT
  26. from fate_flow.tests.grpc.xthread import ThreadPoolExecutor
  27. def wrap_grpc_packet(json_body, http_method, url, src_party_id, dst_party_id, job_id=None, overall_timeout=None):
  28. overall_timeout = JobDefaultConfig.remote_request_timeout if overall_timeout is None else overall_timeout
  29. _src_end_point = basic_meta_pb2.Endpoint(ip=HOST, port=GRPC_PORT)
  30. _src = proxy_pb2.Topic(name=job_id, partyId="{}".format(src_party_id), role=FATE_FLOW_SERVICE_NAME, callback=_src_end_point)
  31. _dst = proxy_pb2.Topic(name=job_id, partyId="{}".format(dst_party_id), role=FATE_FLOW_SERVICE_NAME, callback=None)
  32. _task = proxy_pb2.Task(taskId=job_id)
  33. _command = proxy_pb2.Command(name=FATE_FLOW_SERVICE_NAME)
  34. _conf = proxy_pb2.Conf(overallTimeout=overall_timeout)
  35. _meta = proxy_pb2.Metadata(src=_src, dst=_dst, task=_task, command=_command, operator=http_method, conf=_conf)
  36. _data = proxy_pb2.Data(key=url, value=bytes(json_dumps(json_body), 'utf-8'))
  37. return proxy_pb2.Packet(header=_meta, body=_data)
  38. def get_url(_suffix):
  39. return "http://{}:{}/{}".format(RuntimeConfig.JOB_SERVER_HOST, RuntimeConfig.HTTP_PORT, _suffix.lstrip('/'))
  40. class UnaryService(proxy_pb2_grpc.DataTransferServiceServicer):
  41. @staticmethod
  42. def unaryCall(_request, context):
  43. packet = _request
  44. header = packet.header
  45. _suffix = packet.body.key
  46. param_bytes = packet.body.value
  47. param = bytes.decode(param_bytes)
  48. job_id = header.task.taskId
  49. src = header.src
  50. dst = header.dst
  51. method = header.operator
  52. param_dict = json_loads(param)
  53. param_dict['src_party_id'] = str(src.partyId)
  54. source_routing_header = []
  55. for key, value in context.invocation_metadata():
  56. source_routing_header.append((key, value))
  57. stat_logger.info(f"grpc request routing header: {source_routing_header}")
  58. action = getattr(requests, method.lower(), None)
  59. if action:
  60. print(_suffix)
  61. else:
  62. pass
  63. resp_json = {"status": "test"}
  64. import time
  65. print("sleep")
  66. time.sleep(60)
  67. return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId, src.partyId, job_id)
  68. thread_pool_executor = ThreadPoolExecutor(max_workers=5)
  69. print(f"start grpc server pool on {thread_pool_executor._max_workers} max workers")
  70. server = grpc.server(thread_pool_executor,
  71. options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
  72. (cygrpc.ChannelArgKey.max_receive_message_length, -1)])
  73. proxy_pb2_grpc.add_DataTransferServiceServicer_to_server(UnaryService(), server)
  74. server.add_insecure_port("{}:{}".format("127.0.0.1", 9360))
  75. server.start()
  76. try:
  77. while True:
  78. time.sleep(60 * 60 * 24)
  79. except KeyboardInterrupt:
  80. server.stop(0)
  81. sys.exit(0)