service.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import argparse
  2. import logging
  3. from easyfl.communication import grpc_wrapper
  4. from easyfl.pb import common_pb2 as common_pb
  5. from easyfl.pb import tracking_service_pb2 as tracking_pb
  6. from easyfl.pb import tracking_service_pb2_grpc as tracking_grpc
  7. from easyfl.tracking import metric
  8. from easyfl.tracking.storage import SqliteStorage
  9. logger = logging.getLogger(__name__)
  10. def create_argument_parser():
  11. parser = argparse.ArgumentParser(description='Federated Tracker')
  12. parser.add_argument('--local-port',
  13. type=int,
  14. default=12666,
  15. help='Listen port of the client')
  16. return parser
  17. class TrackingService(tracking_grpc.TrackingServiceServicer):
  18. def __init__(self, storage=SqliteStorage):
  19. self._storage = storage()
  20. logger.info("Tracking service is online")
  21. self._storage.setup()
  22. def TrackTaskMetric(self, request, context):
  23. response = tracking_pb.TrackTaskMetricResponse(
  24. status=common_pb.Status(code=common_pb.SC_OK),
  25. )
  26. try:
  27. self._storage.store_task_metric(metric.TaskMetric.from_proto(request.task_metric))
  28. except Exception as e:
  29. response.status.code = common_pb.SC_UNKNOWN
  30. response.status.message = f"Failed to track task metric, err: {e}"
  31. logger.error(response.status.message)
  32. return response
  33. def TrackRoundMetric(self, request, context):
  34. response = tracking_pb.TrackRoundMetricResponse(
  35. status=common_pb.Status(code=common_pb.SC_OK),
  36. )
  37. try:
  38. self._storage.store_round_metric(metric.RoundMetric.from_proto(request.round_metric))
  39. except Exception as e:
  40. response.status.code = common_pb.SC_UNKNOWN
  41. response.status.message = f"Failed to track round metric, err: {e}"
  42. logger.error(response.status.message)
  43. return response
  44. def TrackClientMetric(self, request, context):
  45. response = tracking_pb.TrackClientMetricResponse(
  46. status=common_pb.Status(code=common_pb.SC_OK),
  47. )
  48. try:
  49. metrics = [metric.ClientMetric.from_proto(m) for m in request.client_metrics]
  50. self._storage.store_client_metrics(metrics)
  51. except Exception as e:
  52. response.status.code = common_pb.SC_UNKNOWN
  53. response.status.message = f"Failed to track client metric, err: {e}"
  54. logger.error(response.status.message)
  55. return response
  56. def TrackClientTrainMetric(self, request, context):
  57. response = tracking_pb.TrackClientTrainMetricResponse(
  58. status=common_pb.Status(code=common_pb.SC_OK),
  59. )
  60. try:
  61. self._storage.store_client_train_metric(request.task_id,
  62. request.round_id,
  63. request.client_id,
  64. request.train_loss,
  65. request.train_time,
  66. request.train_upload_time,
  67. request.train_download_size,
  68. request.train_upload_size)
  69. except Exception as e:
  70. response.status.code = common_pb.SC_UNKNOWN
  71. response.status.message = "Tracking client train failed, err: {}".format(e)
  72. logger.error("Tracking client train failed, err: {}".format(e))
  73. return response
  74. def TrackClientTestMetric(self, request, context):
  75. response = tracking_pb.TrackClientTestMetricResponse(
  76. status=common_pb.Status(code=common_pb.SC_OK),
  77. )
  78. try:
  79. self._storage.store_client_test_metric(request.task_id,
  80. request.round_id,
  81. request.client_id,
  82. request.test_accuracy,
  83. request.test_loss,
  84. request.test_time,
  85. request.test_upload_time,
  86. request.test_download_size)
  87. except Exception as e:
  88. response.status.code = common_pb.SC_UNKNOWN
  89. response.status.message = "Tracking client test failed, err: {}".format(e)
  90. logger.error("Tracking client test failed, err: {}".format(e))
  91. return response
  92. def GetRoundTrainTestTime(self, request, context):
  93. response = tracking_pb.GetRoundTrainTestTimeResponse(
  94. status=common_pb.Status(code=common_pb.SC_OK),
  95. )
  96. try:
  97. resp = self._storage.get_round_train_test_time(request.task_id,
  98. request.rounds,
  99. request.interval)
  100. for i in resp:
  101. train_test_time = tracking_pb.TrainTestTime(round_id=i[0], time=i[1])
  102. response.train_test_times.append(train_test_time)
  103. except Exception as e:
  104. response.status.code = common_pb.SC_UNKNOWN
  105. response.status.message = "get round train_test time failed, err: {}".format(e)
  106. logger.error("get round train_test time failed, err: {}".format(e))
  107. return response
  108. def GetRoundMetrics(self, request, context):
  109. response = tracking_pb.GetRoundMetricsResponse(
  110. status=common_pb.Status(code=common_pb.SC_OK),
  111. )
  112. try:
  113. resp = self._storage.get_round_metrics(request.task_id, request.rounds)
  114. response.metrics = [metric.RoundMetric.from_sql(r) for r in resp]
  115. except Exception as e:
  116. response.status.code = common_pb.SC_UNKNOWN
  117. response.status.message = f"Failed to get round metrics, err: {e}"
  118. logger.error(response.status.message)
  119. return response
  120. def GetClientMetrics(self, request, context):
  121. response = tracking_pb.GetClientMetricsResponse(
  122. status=common_pb.Status(code=common_pb.SC_OK),
  123. )
  124. try:
  125. resp = self._storage.get_client_metrics(request.task_id, request.round_id, request.client_ids)
  126. response.metrics = [metric.ClientMetric.from_sql(r) for r in resp]
  127. except Exception as e:
  128. response.status.code = common_pb.SC_UNKNOWN
  129. response.status.message = f"Failed to get client metrics failed, err: {e}"
  130. logger.error(response.status.message)
  131. return response
  132. def start_tracking_service(local_port=12666):
  133. logger.info("Tracking GRPC server started at :{}".format(local_port))
  134. grpc_wrapper.start_service(grpc_wrapper.TYPE_TRACKING, TrackingService(), local_port)