grpc_wrapper.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from concurrent import futures
  2. import grpc
  3. from easyfl.pb import client_service_pb2_grpc as client_grpc
  4. from easyfl.pb import server_service_pb2_grpc as server_grpc
  5. from easyfl.pb import tracking_service_pb2_grpc as tracking_grpc
  6. MAX_MESSAGE_LENGTH = 524288000 # 500MB
  7. TYPE_CLIENT = "client"
  8. TYPE_SERVER = "server"
  9. TYPE_TRACKING = "tracking"
  10. def init_stub(typ, address):
  11. """Initialize gRPC stub.
  12. Args:
  13. typ (str): Type of service, option: client, server, tracking
  14. address (str): Address of the gRPC service.
  15. Returns:
  16. (:obj:`ClientServiceStub`|:obj:`ServerServiceStub`|:obj:`TrackingServiceStub`): stub of the gRPC service.
  17. """
  18. channel = grpc.insecure_channel(
  19. address,
  20. options=[
  21. ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
  22. ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
  23. ],
  24. )
  25. if typ == TYPE_CLIENT:
  26. stub = client_grpc.ClientServiceStub(channel)
  27. elif typ == TYPE_TRACKING:
  28. stub = tracking_grpc.TrackingServiceStub(channel)
  29. else:
  30. stub = server_grpc.ServerServiceStub(channel)
  31. return stub
  32. def start_service(typ, service, port):
  33. """Start gRPC service.
  34. Args:
  35. typ (str): Type of service, option: client, server, tracking.
  36. service (:obj:`ClientService`|:obj:`ServerService`|:obj:`TrackingService`): gRPC service to start.
  37. port (int): The port of the service.
  38. """
  39. server = grpc.server(
  40. futures.ThreadPoolExecutor(max_workers=10),
  41. options=[
  42. ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
  43. ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
  44. ],
  45. )
  46. if typ == TYPE_CLIENT:
  47. client_grpc.add_ClientServiceServicer_to_server(service, server)
  48. elif typ == TYPE_TRACKING:
  49. tracking_grpc.add_TrackingServiceServicer_to_server(service, server)
  50. else:
  51. server_grpc.add_ServerServiceServicer_to_server(service, server)
  52. server.add_insecure_port('[::]:{}'.format(port))
  53. server.start()
  54. server.wait_for_termination()
  55. def endpoint(host, port):
  56. """Format endpoint.
  57. Args:
  58. host (str): Host address.
  59. port (int): Port number.
  60. Returns:
  61. str: Address in `host:port` format.
  62. """
  63. return "{}:{}".format(host, port)