123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import argparse
- import logging
- from easyfl.communication import grpc_wrapper
- from easyfl.pb import common_pb2 as common_pb
- from easyfl.pb import tracking_service_pb2 as tracking_pb
- from easyfl.pb import tracking_service_pb2_grpc as tracking_grpc
- from easyfl.tracking import metric
- from easyfl.tracking.storage import SqliteStorage
# Module-scoped logger used by every tracking-service entry point below.
logger = logging.getLogger(name=__name__)
def create_argument_parser():
    """Build the command-line parser for the standalone federated tracker.

    Returns:
        argparse.ArgumentParser: parser exposing ``--local-port`` (int, default
        12666 -- the same default port used by ``start_tracking_service``).
    """
    parser = argparse.ArgumentParser(description='Federated Tracker')
    parser.add_argument('--local-port',
                        type=int,
                        default=12666,
                        # This parser belongs to the tracker, not to an FL client.
                        help='Listen port of the tracking service')
    return parser
class TrackingService(tracking_grpc.TrackingServiceServicer):
    """gRPC servicer that stores and serves federated-learning tracking metrics.

    Every RPC builds a response whose ``status.code`` starts as ``SC_OK``. Any
    exception raised while converting the request or talking to the storage
    backend is caught, logged, and reported to the caller as ``SC_UNKNOWN``
    with the error text in ``status.message``; exceptions are not propagated
    to gRPC.

    Args:
        storage: zero-argument callable (by default the ``SqliteStorage``
            class) that builds the storage backend; ``setup()`` is called on
            the new instance during construction.
    """

    def __init__(self, storage=SqliteStorage):
        self._storage = storage()
        self._storage.setup()
        # Only report the service as online once the backend setup succeeded.
        logger.info("Tracking service is online")

    @staticmethod
    def _mark_failed(response, message):
        """Set ``response.status`` to ``SC_UNKNOWN`` with ``message`` and log it."""
        response.status.code = common_pb.SC_UNKNOWN
        response.status.message = message
        logger.error(message)

    def TrackTaskMetric(self, request, context):
        """Store the task-level metric carried in ``request.task_metric``."""
        response = tracking_pb.TrackTaskMetricResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        try:
            self._storage.store_task_metric(metric.TaskMetric.from_proto(request.task_metric))
        except Exception as e:
            self._mark_failed(response, f"Failed to track task metric, err: {e}")
        return response

    def TrackRoundMetric(self, request, context):
        """Store the round-level metric carried in ``request.round_metric``."""
        response = tracking_pb.TrackRoundMetricResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        try:
            self._storage.store_round_metric(metric.RoundMetric.from_proto(request.round_metric))
        except Exception as e:
            self._mark_failed(response, f"Failed to track round metric, err: {e}")
        return response

    def TrackClientMetric(self, request, context):
        """Store every client metric in ``request.client_metrics`` in one call."""
        response = tracking_pb.TrackClientMetricResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        try:
            metrics = [metric.ClientMetric.from_proto(m) for m in request.client_metrics]
            self._storage.store_client_metrics(metrics)
        except Exception as e:
            self._mark_failed(response, f"Failed to track client metric, err: {e}")
        return response

    def TrackClientTrainMetric(self, request, context):
        """Store one client's training statistics, passed field by field to storage."""
        response = tracking_pb.TrackClientTrainMetricResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        try:
            self._storage.store_client_train_metric(request.task_id,
                                                    request.round_id,
                                                    request.client_id,
                                                    request.train_loss,
                                                    request.train_time,
                                                    request.train_upload_time,
                                                    request.train_download_size,
                                                    request.train_upload_size)
        except Exception as e:
            self._mark_failed(response, f"Tracking client train failed, err: {e}")
        return response

    def TrackClientTestMetric(self, request, context):
        """Store one client's testing statistics, passed field by field to storage."""
        response = tracking_pb.TrackClientTestMetricResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        try:
            self._storage.store_client_test_metric(request.task_id,
                                                   request.round_id,
                                                   request.client_id,
                                                   request.test_accuracy,
                                                   request.test_loss,
                                                   request.test_time,
                                                   request.test_upload_time,
                                                   request.test_download_size)
        except Exception as e:
            self._mark_failed(response, f"Tracking client test failed, err: {e}")
        return response

    def GetRoundTrainTestTime(self, request, context):
        """Return per-round train/test time rows for ``request.task_id``.

        ``request.rounds`` and ``request.interval`` are forwarded unchanged to
        the storage query. Each returned row becomes one ``TrainTestTime`` built
        from its first two columns (round id, time).
        """
        response = tracking_pb.GetRoundTrainTestTimeResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        try:
            rows = self._storage.get_round_train_test_time(request.task_id,
                                                           request.rounds,
                                                           request.interval)
            for row in rows:
                response.train_test_times.append(
                    tracking_pb.TrainTestTime(round_id=row[0], time=row[1]))
        except Exception as e:
            self._mark_failed(response, f"get round train_test time failed, err: {e}")
        return response

    def GetRoundMetrics(self, request, context):
        """Return the stored round metrics for ``request.task_id`` / ``request.rounds``."""
        response = tracking_pb.GetRoundMetricsResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        try:
            rows = self._storage.get_round_metrics(request.task_id, request.rounds)
            # Repeated message fields cannot be assigned to (``response.metrics = [...]``
            # raises AttributeError) and only accept protobuf messages, so convert
            # each row to a proto and extend the field in place.
            # NOTE(review): assumes RoundMetric exposes to_proto() as the counterpart
            # of from_proto() -- confirm against easyfl.tracking.metric.
            response.metrics.extend(metric.RoundMetric.from_sql(r).to_proto() for r in rows)
        except Exception as e:
            self._mark_failed(response, f"Failed to get round metrics, err: {e}")
        return response

    def GetClientMetrics(self, request, context):
        """Return stored client metrics for a task round, filtered by ``request.client_ids``."""
        response = tracking_pb.GetClientMetricsResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        try:
            rows = self._storage.get_client_metrics(request.task_id, request.round_id, request.client_ids)
            # Same repeated-field constraint as GetRoundMetrics: extend with protos.
            # NOTE(review): assumes ClientMetric exposes to_proto() -- confirm.
            response.metrics.extend(metric.ClientMetric.from_sql(r).to_proto() for r in rows)
        except Exception as e:
            self._mark_failed(response, f"Failed to get client metrics, err: {e}")
        return response
def start_tracking_service(local_port=12666):
    """Launch the tracking gRPC service on ``local_port``.

    Logs the port, builds a ``TrackingService`` with its default storage
    backend, and hands it to ``grpc_wrapper.start_service`` under the
    tracking service type.
    """
    logger.info(f"Tracking GRPC server started at :{local_port}")
    servicer = TrackingService()
    grpc_wrapper.start_service(grpc_wrapper.TYPE_TRACKING, servicer, local_port)
|