1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- import time
- import sys
- import grpc
- import requests
- from grpc._cython import cygrpc
- from fate_arch.protobuf.python import basic_meta_pb2, proxy_pb2, proxy_pb2_grpc
- from fate_arch.common.base_utils import json_dumps, json_loads
- from fate_flow.db.runtime_config import RuntimeConfig
- from fate_flow.db.job_default_config import JobDefaultConfig
- from fate_flow.settings import FATE_FLOW_SERVICE_NAME, stat_logger, HOST, GRPC_PORT
- from fate_flow.tests.grpc.xthread import ThreadPoolExecutor
def wrap_grpc_packet(json_body, http_method, url, src_party_id, dst_party_id, job_id=None, overall_timeout=None):
    """Wrap a JSON-serialisable body into a ``proxy_pb2.Packet``.

    The packet header records the source and destination topics (both named
    after ``job_id`` and tagged with ``FATE_FLOW_SERVICE_NAME`` as role), the
    task id, the command name, the HTTP method as ``operator`` and the overall
    timeout. The packet body uses ``url`` as key and the UTF-8 encoded JSON
    dump of ``json_body`` as value.

    :param json_body: object serialised with ``json_dumps`` into the body value
    :param http_method: stored in the header's ``operator`` field
    :param url: stored as the body's ``key``
    :param src_party_id: formatted to text and used as the source ``partyId``
    :param dst_party_id: formatted to text and used as the destination ``partyId``
    :param job_id: used as the topic names and the task id
    :param overall_timeout: falls back to
        ``JobDefaultConfig.remote_request_timeout`` when ``None``
    :return: the assembled ``proxy_pb2.Packet``
    """
    if overall_timeout is None:
        overall_timeout = JobDefaultConfig.remote_request_timeout

    # Only the source topic carries a callback endpoint (this process's
    # HOST / GRPC_PORT); the destination topic is built with callback=None.
    callback_endpoint = basic_meta_pb2.Endpoint(ip=HOST, port=GRPC_PORT)
    source_topic = proxy_pb2.Topic(
        name=job_id,
        partyId=f"{src_party_id}",
        role=FATE_FLOW_SERVICE_NAME,
        callback=callback_endpoint,
    )
    destination_topic = proxy_pb2.Topic(
        name=job_id,
        partyId=f"{dst_party_id}",
        role=FATE_FLOW_SERVICE_NAME,
        callback=None,
    )

    header = proxy_pb2.Metadata(
        src=source_topic,
        dst=destination_topic,
        task=proxy_pb2.Task(taskId=job_id),
        command=proxy_pb2.Command(name=FATE_FLOW_SERVICE_NAME),
        operator=http_method,
        conf=proxy_pb2.Conf(overallTimeout=overall_timeout),
    )

    payload = json_dumps(json_body).encode('utf-8')
    body = proxy_pb2.Data(key=url, value=payload)
    return proxy_pb2.Packet(header=header, body=body)
def get_url(_suffix):
    """Return the HTTP URL for ``_suffix`` on the configured job server.

    The host and port are read from ``RuntimeConfig.JOB_SERVER_HOST`` and
    ``RuntimeConfig.HTTP_PORT``; leading slashes are stripped from ``_suffix``
    so the result contains exactly one ``/`` between the port and the path.
    """
    host = RuntimeConfig.JOB_SERVER_HOST
    port = RuntimeConfig.HTTP_PORT
    path = _suffix.lstrip('/')
    return "http://{}:{}/{}".format(host, port, path)
class UnaryService(proxy_pb2_grpc.DataTransferServiceServicer):
    """Stub data-transfer servicer that replies with a fixed payload after a delay.

    The handler parses the incoming packet, logs the call's invocation
    metadata, waits ``RESPONSE_DELAY_SECONDS`` and then answers with
    ``{"status": "test"}``. It never forwards the request over HTTP.
    """

    # Seconds to block before replying. Kept at the original 60 by default;
    # exposed as a class attribute so a test can shorten it.
    RESPONSE_DELAY_SECONDS = 60

    @staticmethod
    def unaryCall(_request, context):
        """Handle one unary call and return a delayed canned response packet.

        :param _request: incoming ``proxy_pb2.Packet``
        :param context: gRPC servicer context; only ``invocation_metadata()``
            is read from it
        :return: packet built by ``wrap_grpc_packet`` with source and
            destination party ids swapped relative to the request
        :raises: whatever ``bytes.decode`` / ``json_loads`` raise when the
            request body is not valid JSON text
        """
        header = _request.header
        body = _request.body
        _suffix = body.key
        method = header.operator
        job_id = header.task.taskId
        src = header.src
        dst = header.dst

        # The parsed payload is not used for the reply; parsing is kept so a
        # malformed body still fails the call exactly as before.
        param_dict = json_loads(body.value.decode())
        param_dict['src_party_id'] = str(src.partyId)

        source_routing_header = [
            (key, value) for key, value in context.invocation_metadata()
        ]
        stat_logger.info(f"grpc request routing header: {source_routing_header}")

        # Only echo the target path when the operator names a callable that
        # the ``requests`` module provides; the request itself is never sent.
        action = getattr(requests, method.lower(), None)
        if action:
            print(_suffix)

        resp_json = {"status": "test"}
        print("sleep")
        time.sleep(UnaryService.RESPONSE_DELAY_SECONDS)
        return wrap_grpc_packet(resp_json, method, _suffix, dst.partyId, src.partyId, job_id)
# --- Server bootstrap -------------------------------------------------------
# Runs at import time: builds the gRPC server on a 5-worker pool, registers
# UnaryService, listens on 127.0.0.1:9360 and blocks until Ctrl-C.
thread_pool_executor = ThreadPoolExecutor(max_workers=5)
print(f"start grpc server pool on {thread_pool_executor._max_workers} max workers")

# Both message-length options are set to -1.
# NOTE(review): gRPC documents -1 as "unlimited" — confirm for the pinned version.
_channel_options = [
    (cygrpc.ChannelArgKey.max_send_message_length, -1),
    (cygrpc.ChannelArgKey.max_receive_message_length, -1),
]
server = grpc.server(thread_pool_executor, options=_channel_options)
proxy_pb2_grpc.add_DataTransferServiceServicer_to_server(UnaryService(), server)

_listen_host = "127.0.0.1"
_listen_port = 9360
server.add_insecure_port("{}:{}".format(_listen_host, _listen_port))
server.start()

# server.start() returns immediately, so keep the main thread alive by
# sleeping one day at a time until the process is interrupted.
_one_day_seconds = 60 * 60 * 24
try:
    while True:
        time.sleep(_one_day_seconds)
except KeyboardInterrupt:
    server.stop(0)
    sys.exit(0)
|