1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- from concurrent import futures
- import grpc
- from easyfl.pb import client_service_pb2_grpc as client_grpc
- from easyfl.pb import server_service_pb2_grpc as server_grpc
- from easyfl.pb import tracking_service_pb2_grpc as tracking_grpc
# Upper bound, in bytes, on any single gRPC message sent or received by the
# channels and servers created in this module: 500 MiB (500 * 1024 * 1024).
MAX_MESSAGE_LENGTH = 500 * 1024 * 1024

# Service type identifiers compared against the ``typ`` argument of
# ``init_stub`` and ``start_service``.
TYPE_CLIENT = "client"
TYPE_SERVER = "server"
TYPE_TRACKING = "tracking"
def init_stub(typ, address):
    """Create a gRPC stub connected over an insecure channel to ``address``.

    Both the send and receive message-size limits of the channel are raised
    to ``MAX_MESSAGE_LENGTH``.

    Args:
        typ (str): Type of service, option: client, server, tracking. Any value
            other than ``TYPE_CLIENT`` or ``TYPE_TRACKING`` yields a server stub.
        address (str): Address of the gRPC service.

    Returns:
        (:obj:`ClientServiceStub`|:obj:`ServerServiceStub`|:obj:`TrackingServiceStub`): stub of the gRPC service.
    """
    size_limits = [
        ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
        ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
    ]
    channel = grpc.insecure_channel(address, options=size_limits)

    if typ == TYPE_CLIENT:
        return client_grpc.ClientServiceStub(channel)
    if typ == TYPE_TRACKING:
        return tracking_grpc.TrackingServiceStub(channel)
    # TYPE_SERVER and every unrecognised value fall through to the server stub.
    return server_grpc.ServerServiceStub(channel)
def start_service(typ, service, port, max_workers=10):
    """Start a gRPC server hosting ``service`` on ``port`` and block until it terminates.

    Both the send and receive message-size limits of the server are raised
    to ``MAX_MESSAGE_LENGTH``.

    Args:
        typ (str): Type of service, option: client, server, tracking. Any value
            other than ``TYPE_CLIENT`` or ``TYPE_TRACKING`` registers
            ``service`` as a server servicer.
        service (:obj:`ClientService`|:obj:`ServerService`|:obj:`TrackingService`): gRPC service to start.
        port (int): The port of the service. The server binds to the wildcard
            address ``[::]`` on this port.
        max_workers (int): Number of threads in the pool that handles RPCs.
            Defaults to 10.

    Raises:
        RuntimeError: If the server could not bind to ``port``.
    """
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        options=[
            ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
            ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
        ],
    )
    if typ == TYPE_CLIENT:
        client_grpc.add_ClientServiceServicer_to_server(service, server)
    elif typ == TYPE_TRACKING:
        tracking_grpc.add_TrackingServiceServicer_to_server(service, server)
    else:
        server_grpc.add_ServerServiceServicer_to_server(service, server)

    address = '[::]:{}'.format(port)
    bound_port = server.add_insecure_port(address)
    if bound_port == 0:
        # Older grpcio releases report a bind failure by returning 0 instead of
        # raising. Without this check the server would start with no listening
        # port and wait_for_termination() below would block while unreachable.
        raise RuntimeError("Failed to bind gRPC server to {}".format(address))
    server.start()
    server.wait_for_termination()
def endpoint(host, port):
    """Join ``host`` and ``port`` into a single ``host:port`` address string.

    The host is inserted verbatim, so an IPv6 literal is not wrapped in
    brackets here.

    Args:
        host (str): Host address.
        port (int): Port number.

    Returns:
        str: ``host`` and ``port`` separated by a colon, e.g. ``"localhost:23400"``.
    """
    return f"{host}:{port}"
|