# service.py

import logging
import threading

from easyfl.pb import server_service_pb2_grpc as server_grpc, server_service_pb2 as server_pb, common_pb2 as common_pb
from easyfl.protocol import codec
from easyfl.tracking import metric

logger = logging.getLogger(__name__)


class ServerService(server_grpc.ServerServiceServicer):
    """Remote gRPC server service.

    Args:
        server (:obj:`BaseServer`): Federated learning server instance.
    """

    def __init__(self, server):
        self._base = server
        self._clients_per_round = 0
        self._train_client_count = 0
        self._uploaded_models = {}
        self._uploaded_weights = {}
        self._uploaded_metrics = []
        self._test_client_count = 0
        self._accuracies = []
        self._losses = []
        self._test_sizes = []

    def Run(self, request, context):
        """Trigger federated learning process."""
        response = server_pb.RunResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        if self._base.is_training():
            response = server_pb.RunResponse(
                status=common_pb.Status(
                    code=common_pb.SC_ALREADY_EXISTS,
                    message="Training in progress, please stop current training or wait for completion",
                ),
            )
        else:
            model = codec.unmarshal(request.model)
            self._base.start_remote_training(model, request.clients)
        return response

    def Stop(self, request, context):
        """Stop federated learning process."""
        response = server_pb.StopResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        if self._base.is_training():
            self._base.stop()
        else:
            response = server_pb.StopResponse(
                status=common_pb.Status(
                    code=common_pb.SC_NOT_FOUND,
                    message="No existing training",
                ),
            )
        return response

    def Upload(self, request, context):
        """Handle upload from clients."""
        # TODO: put train and test logic in a separate thread and add a thread lock to ensure atomicity.
        # Handle the upload asynchronously so that the RPC can return immediately.
        t = threading.Thread(target=self._handle_upload, args=[request, context])
        t.start()
        response = server_pb.UploadResponse(
            status=common_pb.Status(code=common_pb.SC_OK),
        )
        return response
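
    # Hedged client-side sketch (an assumption, not part of the original
    # module): callers reach Run/Stop/Upload through the stub that grpcio
    # generates for `ServerService`. The stub class name `ServerServiceStub`
    # follows the standard code-generation convention; the request field names
    # below mirror how this servicer reads the requests, and `codec.marshal`
    # is assumed as the counterpart of `codec.unmarshal`. The exact message
    # definitions live in the .proto file.
    #
    #   channel = grpc.insecure_channel(server_address)  # e.g. "host:port"
    #   stub = server_grpc.ServerServiceStub(channel)
    #   # Start remote training with a marshaled model and a list of clients.
    #   stub.Run(server_pb.RunRequest(model=codec.marshal(model), clients=clients))
    #   # Stop the current training, if one is running.
    #   stub.Stop(server_pb.StopRequest())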

    def _handle_upload(self, request, context):
        data = codec.unmarshal(request.content.data)
        data_size = request.content.data_size
        client_metric = metric.ClientMetric.from_proto(request.content.metric)

        clients_per_round = self._base.conf.server.clients_per_round
        num_of_clients = self._base.num_of_clients()
        if num_of_clients < clients_per_round:
            # TODO: use a more appropriate way to handle this situation.
            logger.warning(
                "Available number of clients {} is smaller than clients per round {}".format(
                    num_of_clients, clients_per_round))
            self._clients_per_round = num_of_clients
        else:
            self._clients_per_round = clients_per_round

        if request.content.type == common_pb.DATA_TYPE_PARAMS:
            self._handle_upload_train(request.client_id, data, data_size, client_metric)
        elif request.content.type == common_pb.DATA_TYPE_PERFORMANCE:
            self._handle_upload_test(data, data_size, client_metric)

    def _handle_upload_train(self, client_id, data, data_size, client_metric):
        model = self._base.decompression(data)
        self._uploaded_models[client_id] = model
        self._uploaded_weights[client_id] = data_size
        self._uploaded_metrics.append(client_metric)
        self._train_client_count += 1
        self._trigger_aggregate_train()

    def _handle_upload_test(self, data, data_size, client_metric):
        self._accuracies.append(data.accuracy)
        self._losses.append(data.loss)
        self._test_sizes.append(data_size)
        self._uploaded_metrics.append(client_metric)
        self._test_client_count += 1
        self._trigger_aggregate_test()

    def _trigger_aggregate_train(self):
        logger.info("train_client_count: {}/{}".format(self._train_client_count, self._clients_per_round))
        if self._train_client_count == self._clients_per_round:
            self._base.set_client_uploads_train(self._uploaded_models, self._uploaded_weights, self._uploaded_metrics)
            self._train_client_count = 0
            self._reset_train_cache()
            # Wake up the server, which waits on this condition until enough
            # clients have uploaded their updates.
            with self._base.condition():
                self._base.notify_all()

    def _trigger_aggregate_test(self):
        # TODO: determine the testing clients not only by the selected number of clients.
        if self._test_client_count == self._clients_per_round:
            self._base.set_client_uploads_test(self._accuracies, self._losses, self._test_sizes, self._uploaded_metrics)
            self._test_client_count = 0
            self._reset_test_cache()
            with self._base.condition():
                self._base.notify_all()

    def _reset_train_cache(self):
        self._uploaded_models = {}
        self._uploaded_weights = {}
        self._uploaded_metrics = []

    def _reset_test_cache(self):
        self._accuracies = []
        self._losses = []
        self._test_sizes = []
        self._uploaded_metrics = []
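

# ---------------------------------------------------------------------------
# Hedged usage sketch (an assumption, not part of the original module): how
# this servicer is typically registered with a grpc.Server. The helper name
# `add_ServerServiceServicer_to_server` follows the standard grpcio
# code-generation convention for a service named `ServerService`; the
# `base_server` instance (a BaseServer) and the port are illustrative, not
# confirmed by this file.
#
#   import grpc
#   from concurrent import futures
#
#   grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
#   server_grpc.add_ServerServiceServicer_to_server(
#       ServerService(base_server),  # base_server: federated learning server
#       grpc_server,
#   )
#   grpc_server.add_insecure_port("[::]:23501")  # port is illustrative
#   grpc_server.start()
#   grpc_server.wait_for_termination()
# ---------------------------------------------------------------------------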