batch_generator.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. from federatedml.framework.hetero.sync import batch_info_sync
  18. from federatedml.model_selection import MiniBatch
  19. from federatedml.util import LOGGER
  20. class Guest(batch_info_sync.Guest):
  21. def __init__(self):
  22. self.mini_batch_obj = None
  23. self.finish_sycn = False
  24. self.batch_nums = None
  25. self.batch_masked = False
  26. def register_batch_generator(self, transfer_variables, has_arbiter=True):
  27. self._register_batch_data_index_transfer(transfer_variables.batch_info,
  28. transfer_variables.batch_data_index,
  29. getattr(transfer_variables, "batch_validate_info", None),
  30. has_arbiter)
  31. def initialize_batch_generator(self, data_instances, batch_size, suffix=tuple(),
  32. shuffle=False, batch_strategy="full", masked_rate=0):
  33. self.mini_batch_obj = MiniBatch(data_instances, batch_size=batch_size, shuffle=shuffle,
  34. batch_strategy=batch_strategy, masked_rate=masked_rate)
  35. self.batch_nums = self.mini_batch_obj.batch_nums
  36. self.batch_masked = self.mini_batch_obj.batch_size != self.mini_batch_obj.masked_batch_size
  37. batch_info = {"batch_size": self.mini_batch_obj.batch_size, "batch_num": self.batch_nums,
  38. "batch_mutable": self.mini_batch_obj.batch_mutable,
  39. "masked_batch_size": self.mini_batch_obj.masked_batch_size}
  40. self.sync_batch_info(batch_info, suffix)
  41. if not self.mini_batch_obj.batch_mutable:
  42. self.prepare_batch_data(suffix)
  43. def prepare_batch_data(self, suffix=tuple()):
  44. self.mini_batch_obj.generate_batch_data()
  45. index_generator = self.mini_batch_obj.mini_batch_data_generator(result='index')
  46. batch_index = 0
  47. for batch_data_index in index_generator:
  48. batch_suffix = suffix + (batch_index,)
  49. self.sync_batch_index(batch_data_index, batch_suffix)
  50. batch_index += 1
  51. def generate_batch_data(self, with_index=False, suffix=tuple()):
  52. if self.mini_batch_obj.batch_mutable:
  53. self.prepare_batch_data(suffix)
  54. if with_index:
  55. data_generator = self.mini_batch_obj.mini_batch_data_generator(result='both')
  56. for batch_data, index_data in data_generator:
  57. yield batch_data, index_data
  58. else:
  59. data_generator = self.mini_batch_obj.mini_batch_data_generator(result='data')
  60. for batch_data in data_generator:
  61. yield batch_data
  62. def verify_batch_legality(self, suffix=tuple()):
  63. validate_infos = self.sync_batch_validate_info(suffix)
  64. least_batch_size = 0
  65. is_legal = True
  66. for validate_info in validate_infos:
  67. legality = validate_info.get("legality")
  68. if not legality:
  69. is_legal = False
  70. least_batch_size = max(least_batch_size, validate_info.get("least_batch_size"))
  71. if not is_legal:
  72. raise ValueError(f"To use batch masked strategy, "
  73. f"(masked_rate + 1) * batch_size should > {least_batch_size}")
  74. class Host(batch_info_sync.Host):
  75. def __init__(self):
  76. self.finish_sycn = False
  77. self.batch_data_insts = []
  78. self.batch_nums = None
  79. self.data_inst = None
  80. self.batch_mutable = False
  81. self.batch_masked = False
  82. self.masked_batch_size = None
  83. def register_batch_generator(self, transfer_variables, has_arbiter=None):
  84. self._register_batch_data_index_transfer(transfer_variables.batch_info,
  85. transfer_variables.batch_data_index,
  86. getattr(transfer_variables, "batch_validate_info", None))
  87. def initialize_batch_generator(self, data_instances, suffix=tuple(), **kwargs):
  88. batch_info = self.sync_batch_info(suffix)
  89. batch_size = batch_info.get("batch_size")
  90. self.batch_nums = batch_info.get('batch_num')
  91. self.batch_mutable = batch_info.get("batch_mutable")
  92. self.masked_batch_size = batch_info.get("masked_batch_size")
  93. self.batch_masked = self.masked_batch_size != batch_size
  94. if not self.batch_mutable:
  95. self.prepare_batch_data(data_instances, suffix)
  96. else:
  97. self.data_inst = data_instances
  98. def prepare_batch_data(self, data_inst, suffix=tuple()):
  99. self.batch_data_insts = []
  100. for batch_index in range(self.batch_nums):
  101. batch_suffix = suffix + (batch_index,)
  102. batch_data_index = self.sync_batch_index(suffix=batch_suffix)
  103. # batch_data_inst = batch_data_index.join(data_instances, lambda g, d: d)
  104. batch_data_inst = data_inst.join(batch_data_index, lambda d, g: d)
  105. self.batch_data_insts.append(batch_data_inst)
  106. def generate_batch_data(self, suffix=tuple()):
  107. if self.batch_mutable:
  108. self.prepare_batch_data(data_inst=self.data_inst, suffix=suffix)
  109. batch_index = 0
  110. for batch_data_inst in self.batch_data_insts:
  111. LOGGER.info("batch_num: {}, batch_data_inst size:{}".format(
  112. batch_index, batch_data_inst.count()))
  113. yield batch_data_inst
  114. batch_index += 1
  115. def verify_batch_legality(self, least_batch_size, suffix=tuple()):
  116. if self.masked_batch_size <= least_batch_size:
  117. batch_validate_info = {"legality": False,
  118. "least_batch_size": least_batch_size}
  119. LOGGER.warning(f"masked_batch_size {self.masked_batch_size} is illegal, should > {least_batch_size}")
  120. else:
  121. batch_validate_info = {"legality": True}
  122. self.sync_batch_validate_info(batch_validate_info, suffix)
  123. class Arbiter(batch_info_sync.Arbiter):
  124. def __init__(self):
  125. self.batch_num = None
  126. def register_batch_generator(self, transfer_variables):
  127. self._register_batch_data_index_transfer(transfer_variables.batch_info, transfer_variables.batch_data_index)
  128. def initialize_batch_generator(self, suffix=tuple()):
  129. batch_info = self.sync_batch_info(suffix)
  130. self.batch_num = batch_info.get('batch_num')
  131. def generate_batch_data(self):
  132. for i in range(self.batch_num):
  133. yield i