mini_batch.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import random
  17. from fate_arch.session import computing_session as session
  18. from federatedml.model_selection import indices
  19. from federatedml.util import LOGGER
  20. class MiniBatch:
  21. def __init__(self, data_inst, batch_size=320, shuffle=False, batch_strategy="full", masked_rate=0):
  22. self.batch_data_sids = None
  23. self.batch_nums = 0
  24. self.data_inst = data_inst
  25. self.all_batch_data = None
  26. self.all_index_data = None
  27. self.data_sids_iter = None
  28. self.batch_data_generator = None
  29. self.batch_mutable = False
  30. self.batch_masked = False
  31. if batch_size == -1:
  32. self.batch_size = data_inst.count()
  33. else:
  34. self.batch_size = batch_size
  35. self.__init_mini_batch_data_seperator(data_inst, self.batch_size, batch_strategy, masked_rate, shuffle)
  36. def mini_batch_data_generator(self, result='data'):
  37. """
  38. Generate mini-batch data or index
  39. Parameters
  40. ----------
  41. result : str, 'data' or 'index', default: 'data'
  42. Specify you want batch data or batch index.
  43. Returns
  44. -------
  45. A generator that might generate data or index.
  46. """
  47. LOGGER.debug("Currently, batch_num is: {}".format(self.batch_nums))
  48. if result == 'index':
  49. for index_table in self.all_index_data:
  50. yield index_table
  51. elif result == "data":
  52. for batch_data in self.all_batch_data:
  53. yield batch_data
  54. else:
  55. for batch_data, index_table in zip(self.all_batch_data, self.all_index_data):
  56. yield batch_data, index_table
  57. # if self.batch_mutable:
  58. # self.__generate_batch_data()
  59. def __init_mini_batch_data_seperator(self, data_insts, batch_size, batch_strategy, masked_rate, shuffle):
  60. self.data_sids_iter, data_size = indices.collect_index(data_insts)
  61. self.batch_data_generator = get_batch_generator(
  62. data_size, batch_size, batch_strategy, masked_rate, shuffle=shuffle)
  63. self.batch_nums = self.batch_data_generator.batch_nums
  64. self.batch_mutable = self.batch_data_generator.batch_mutable()
  65. self.masked_batch_size = self.batch_data_generator.masked_batch_size
  66. if self.batch_mutable is False:
  67. self.__generate_batch_data()
  68. def generate_batch_data(self):
  69. if self.batch_mutable:
  70. self.__generate_batch_data()
  71. def __generate_batch_data(self):
  72. self.all_index_data, self.all_batch_data = self.batch_data_generator.generate_data(
  73. self.data_inst, self.data_sids_iter)
  74. def get_batch_generator(data_size, batch_size, batch_strategy, masked_rate, shuffle):
  75. if batch_size >= data_size:
  76. LOGGER.warning("As batch_size >= data size, all batch strategy will be disabled")
  77. return FullBatchDataGenerator(data_size, data_size, shuffle=False)
  78. # if round((masked_rate + 1) * batch_size) >= data_size:
  79. # LOGGER.warning("Masked dataset's batch_size >= data size, batch shuffle will be disabled")
  80. # return FullBatchDataGenerator(data_size, data_size, shuffle=False, masked_rate=masked_rate)
  81. if batch_strategy == "full":
  82. if masked_rate > 0:
  83. LOGGER.warning("If using full batch strategy and masked rate > 0, shuffle will always be true")
  84. shuffle = True
  85. return FullBatchDataGenerator(data_size, batch_size, shuffle=shuffle, masked_rate=masked_rate)
  86. else:
  87. if shuffle:
  88. LOGGER.warning("if use random select batch strategy, shuffle will not work")
  89. return RandomBatchDataGenerator(data_size, batch_size, masked_rate)
  90. class BatchDataGenerator(object):
  91. def __init__(self, data_size, batch_size, shuffle=False, masked_rate=0):
  92. self.batch_nums = None
  93. self.masked_batch_size = min(data_size, round((1 + masked_rate) * batch_size))
  94. self.batch_size = batch_size
  95. self.shuffle = shuffle
  96. def batch_mutable(self):
  97. return True
  98. @staticmethod
  99. def _generate_batch_data_with_batch_ids(data_insts, batch_ids, masked_ids=None):
  100. batch_index_table = session.parallelize(batch_ids,
  101. include_key=True,
  102. partition=data_insts.partitions)
  103. batch_data_table = batch_index_table.join(data_insts, lambda x, y: y)
  104. if masked_ids:
  105. masked_index_table = session.parallelize(masked_ids,
  106. include_key=True,
  107. partition=data_insts.partitions)
  108. return masked_index_table, batch_data_table
  109. else:
  110. return batch_index_table, batch_data_table
  111. class FullBatchDataGenerator(BatchDataGenerator):
  112. def __init__(self, data_size, batch_size, shuffle=False, masked_rate=0):
  113. super(FullBatchDataGenerator, self).__init__(data_size, batch_size, shuffle, masked_rate=masked_rate)
  114. self.batch_nums = (data_size + batch_size - 1) // batch_size
  115. LOGGER.debug(f"Init Full Batch Data Generator, batch_nums: {self.batch_nums}, batch_size: {self.batch_size}, "
  116. f"masked_batch_size: {self.masked_batch_size}, shuffle: {self.shuffle}")
  117. def generate_data(self, data_insts, data_sids):
  118. if self.shuffle:
  119. random.SystemRandom().shuffle(data_sids)
  120. index_table = []
  121. batch_data = []
  122. if self.batch_size != self.masked_batch_size:
  123. for bid in range(self.batch_nums):
  124. batch_ids = data_sids[bid * self.batch_size:(bid + 1) * self.batch_size]
  125. masked_ids_set = set()
  126. for sid, _ in batch_ids:
  127. masked_ids_set.add(sid)
  128. possible_ids = random.SystemRandom().sample(data_sids, self.masked_batch_size)
  129. for pid, _ in possible_ids:
  130. if pid not in masked_ids_set:
  131. masked_ids_set.add(pid)
  132. if len(masked_ids_set) == self.masked_batch_size:
  133. break
  134. masked_ids = zip(list(masked_ids_set), [None] * len(masked_ids_set))
  135. masked_index_table, batch_data_table = self._generate_batch_data_with_batch_ids(data_insts,
  136. batch_ids,
  137. masked_ids)
  138. index_table.append(masked_index_table)
  139. batch_data.append(batch_data_table)
  140. else:
  141. for bid in range(self.batch_nums):
  142. batch_ids = data_sids[bid * self.batch_size: (bid + 1) * self.batch_size]
  143. batch_index_table, batch_data_table = self._generate_batch_data_with_batch_ids(data_insts, batch_ids)
  144. index_table.append(batch_index_table)
  145. batch_data.append(batch_data_table)
  146. return index_table, batch_data
  147. def batch_mutable(self):
  148. return self.masked_batch_size > self.batch_size or self.shuffle
  149. class RandomBatchDataGenerator(BatchDataGenerator):
  150. def __init__(self, data_size, batch_size, masked_rate=0):
  151. super(RandomBatchDataGenerator, self).__init__(data_size, batch_size, shuffle=False, masked_rate=masked_rate)
  152. self.batch_nums = 1
  153. LOGGER.debug(f"Init Random Batch Data Generator, batch_nums: {self.batch_nums}, batch_size: {self.batch_size}, "
  154. f"masked_batch_size: {self.masked_batch_size}")
  155. def generate_data(self, data_insts, data_sids):
  156. if self.masked_batch_size == self.batch_size:
  157. batch_ids = random.SystemRandom().sample(data_sids, self.batch_size)
  158. batch_index_table, batch_data_table = self._generate_batch_data_with_batch_ids(data_insts, batch_ids)
  159. batch_data_table = batch_index_table.join(data_insts, lambda x, y: y)
  160. return [batch_index_table], [batch_data_table]
  161. else:
  162. masked_ids = random.SystemRandom().sample(data_sids, self.masked_batch_size)
  163. batch_ids = masked_ids[: self.batch_size]
  164. masked_index_table, batch_data_table = self._generate_batch_data_with_batch_ids(data_insts,
  165. batch_ids,
  166. masked_ids)
  167. return [masked_index_table], [batch_data_table]