batch_info_sync.py 4.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. from federatedml.util import LOGGER
  18. from federatedml.util import consts
  19. class Guest(object):
  20. def _register_batch_data_index_transfer(self, batch_data_info_transfer,
  21. batch_data_index_transfer,
  22. batch_validate_info_transfer,
  23. has_arbiter):
  24. self.batch_data_info_transfer = batch_data_info_transfer.disable_auto_clean()
  25. self.batch_data_index_transfer = batch_data_index_transfer.disable_auto_clean()
  26. self.batch_validate_info_transfer = batch_validate_info_transfer
  27. self.has_arbiter = has_arbiter
  28. def sync_batch_info(self, batch_info, suffix=tuple()):
  29. self.batch_data_info_transfer.remote(obj=batch_info,
  30. role=consts.HOST,
  31. suffix=suffix)
  32. if self.has_arbiter:
  33. self.batch_data_info_transfer.remote(obj=batch_info,
  34. role=consts.ARBITER,
  35. suffix=suffix)
  36. def sync_batch_index(self, batch_index, suffix=tuple()):
  37. self.batch_data_index_transfer.remote(obj=batch_index,
  38. role=consts.HOST,
  39. suffix=suffix)
  40. def sync_batch_validate_info(self, suffix):
  41. if not self.batch_validate_info_transfer:
  42. raise ValueError("batch_validate_info should be create in transfer variable")
  43. validate_info = self.batch_validate_info_transfer.get(idx=-1,
  44. suffix=suffix)
  45. return validate_info
  46. class Host(object):
  47. def _register_batch_data_index_transfer(self, batch_data_info_transfer, batch_data_index_transfer,
  48. batch_validate_info_transfer):
  49. self.batch_data_info_transfer = batch_data_info_transfer.disable_auto_clean()
  50. self.batch_data_index_transfer = batch_data_index_transfer.disable_auto_clean()
  51. self.batch_validate_info_transfer = batch_validate_info_transfer
  52. def sync_batch_info(self, suffix=tuple()):
  53. LOGGER.debug("In sync_batch_info, suffix is :{}".format(suffix))
  54. batch_info = self.batch_data_info_transfer.get(idx=0,
  55. suffix=suffix)
  56. batch_size = batch_info.get('batch_size')
  57. if batch_size < consts.MIN_BATCH_SIZE and batch_size != -1:
  58. raise ValueError(
  59. "Batch size get from guest should not less than {}, except -1, batch_size is {}".format(
  60. consts.MIN_BATCH_SIZE, batch_size))
  61. return batch_info
  62. def sync_batch_index(self, suffix=tuple()):
  63. batch_index = self.batch_data_index_transfer.get(idx=0,
  64. suffix=suffix)
  65. return batch_index
  66. def sync_batch_validate_info(self, validate_info, suffix=tuple()):
  67. self.batch_validate_info_transfer.remote(obj=validate_info,
  68. role=consts.GUEST,
  69. suffix=suffix)
  70. class Arbiter(object):
  71. def _register_batch_data_index_transfer(self, batch_data_info_transfer, batch_data_index_transfer):
  72. self.batch_data_info_transfer = batch_data_info_transfer.disable_auto_clean()
  73. self.batch_data_index_transfer = batch_data_index_transfer.disable_auto_clean()
  74. def sync_batch_info(self, suffix=tuple()):
  75. batch_info = self.batch_data_info_transfer.get(idx=0,
  76. suffix=suffix)
  77. return batch_info