#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
- import random
- from fate_arch.session import computing_session as session
- from federatedml.model_selection import indices
- from federatedml.util import LOGGER
class MiniBatch:
    """Split a data table into mini-batches and hand them out on demand.

    The batching policy is delegated to the object returned by
    ``get_batch_generator``. When that object reports its batches as
    immutable they are built once during construction; otherwise nothing is
    built up front and ``generate_batch_data`` has to be called before
    ``mini_batch_data_generator`` is iterated.
    """

    def __init__(self, data_inst, batch_size=320, shuffle=False, batch_strategy="full", masked_rate=0):
        """
        Parameters
        ----------
        data_inst : table
            Data to be batched; must provide ``count()`` when ``batch_size``
            is -1.
        batch_size : int, default: 320
            Records per batch. ``-1`` means "use ``data_inst.count()``".
        shuffle : bool, default: False
            Forwarded to ``get_batch_generator``.
        batch_strategy : str, default: "full"
            Forwarded to ``get_batch_generator``.
        masked_rate : int or float, default: 0
            Forwarded to ``get_batch_generator``.
        """
        self.data_inst = data_inst

        # Filled in by the private initialisation / generation helpers below.
        self.batch_data_sids = None
        self.batch_nums = 0
        self.all_batch_data = None
        self.all_index_data = None
        self.data_sids_iter = None
        self.batch_data_generator = None
        self.batch_mutable = False
        self.batch_masked = False

        # -1 is shorthand for a single batch that holds every record.
        self.batch_size = data_inst.count() if batch_size == -1 else batch_size

        self.__init_mini_batch_data_seperator(data_inst, self.batch_size, batch_strategy, masked_rate, shuffle)

    def mini_batch_data_generator(self, result='data'):
        """
        Generate mini-batch data or index

        Parameters
        ----------
        result : str, 'data' or 'index', default: 'data'
            'index' yields the index tables, 'data' yields the batch data
            tables. Any other value yields ``(batch_data, index_table)``
            pairs.

        Returns
        -------
        A generator over the requested items, one per batch.
        """
        LOGGER.debug("Currently, batch_num is: {}".format(self.batch_nums))

        if result == 'index':
            for idx_table in self.all_index_data:
                yield idx_table
            return

        if result == "data":
            for data_table in self.all_batch_data:
                yield data_table
            return

        for data_and_index in zip(self.all_batch_data, self.all_index_data):
            yield data_and_index

    def __init_mini_batch_data_seperator(self, data_insts, batch_size, batch_strategy, masked_rate, shuffle):
        """Collect the sample ids, pick a batch generator and cache its settings."""
        sids, total = indices.collect_index(data_insts)
        self.data_sids_iter = sids

        generator = get_batch_generator(total, batch_size, batch_strategy, masked_rate, shuffle=shuffle)
        self.batch_data_generator = generator
        self.batch_nums = generator.batch_nums
        self.batch_mutable = generator.batch_mutable()
        self.masked_batch_size = generator.masked_batch_size

        # Batches that never change are materialised once, right here.
        if self.batch_mutable is not False:
            return
        self.__generate_batch_data()

    def generate_batch_data(self):
        """Rebuild the batches; a no-op when the batches are not mutable."""
        if not self.batch_mutable:
            return
        self.__generate_batch_data()

    def __generate_batch_data(self):
        """Ask the batch generator for fresh index tables and batch data tables."""
        generated = self.batch_data_generator.generate_data(self.data_inst, self.data_sids_iter)
        self.all_index_data, self.all_batch_data = generated
def get_batch_generator(data_size, batch_size, batch_strategy, masked_rate, shuffle):
    """Choose the batch generator that matches the requested strategy.

    Parameters
    ----------
    data_size : int
        Number of records available.
    batch_size : int
        Requested number of records per batch.
    batch_strategy : str
        "full" selects ``FullBatchDataGenerator``; any other value selects
        ``RandomBatchDataGenerator``.
    masked_rate : int or float
        Passed on to the chosen generator (ignored when a single batch
        already covers all the data).
    shuffle : bool
        Honoured only by the "full" strategy, where it is forced to True
        whenever ``masked_rate`` is positive.

    Returns
    -------
    FullBatchDataGenerator or RandomBatchDataGenerator
    """
    # A single batch already covers everything: strategy, masking and
    # shuffling are all switched off.
    if batch_size >= data_size:
        LOGGER.warning("As batch_size >= data size, all batch strategy will be disabled")
        return FullBatchDataGenerator(data_size, data_size, shuffle=False)

    if batch_strategy != "full":
        if shuffle:
            LOGGER.warning("if use random select batch strategy, shuffle will not work")
        return RandomBatchDataGenerator(data_size, batch_size, masked_rate)

    if masked_rate > 0:
        LOGGER.warning("If using full batch strategy and masked rate > 0, shuffle will always be true")
        shuffle = True
    return FullBatchDataGenerator(data_size, batch_size, shuffle=shuffle, masked_rate=masked_rate)
class BatchDataGenerator(object):
    """Common state and table-building helper shared by the batch generators."""

    def __init__(self, data_size, batch_size, shuffle=False, masked_rate=0):
        """
        Parameters
        ----------
        data_size : int
            Number of records available.
        batch_size : int
            Number of records per batch.
        shuffle : bool, default: False
            Stored on the instance for subclasses to consult.
        masked_rate : int or float, default: 0
            The masked batch size is ``round((1 + masked_rate) * batch_size)``,
            capped at ``data_size``.
        """
        self.shuffle = shuffle
        self.batch_size = batch_size
        # Subclasses are expected to overwrite this with a real count.
        self.batch_nums = None

        enlarged_size = round((1 + masked_rate) * batch_size)
        self.masked_batch_size = min(data_size, enlarged_size)

    def batch_mutable(self):
        """Report whether batches may change between generations; True here."""
        return True

    @staticmethod
    def _generate_batch_data_with_batch_ids(data_insts, batch_ids, masked_ids=None):
        """Build the index table and the data table for one batch.

        Parameters
        ----------
        data_insts : table
            Source data; joined against the batch ids.
        batch_ids : iterable of key/value pairs
            Ids that select the records of the batch.
        masked_ids : iterable of key/value pairs, optional
            When truthy, the returned index table is built from these ids
            instead of ``batch_ids``.

        Returns
        -------
        tuple
            ``(index_table, batch_data_table)``; the data table is always the
            join of ``batch_ids`` with ``data_insts``.
        """

        def as_table(id_pairs):
            return session.parallelize(id_pairs,
                                       include_key=True,
                                       partition=data_insts.partitions)

        batch_index_table = as_table(batch_ids)
        # Keep only the instance side of the join, i.e. the selected records.
        batch_data_table = batch_index_table.join(data_insts, lambda _, inst: inst)

        if not masked_ids:
            return batch_index_table, batch_data_table

        masked_index_table = as_table(masked_ids)
        return masked_index_table, batch_data_table
class FullBatchDataGenerator(BatchDataGenerator):
    """Batch generator that covers the id list with consecutive slices.

    Each batch is the next ``batch_size`` entries of the id list. When
    ``masked_batch_size`` differs from ``batch_size`` the index table of every
    batch is padded with randomly drawn ids up to ``masked_batch_size``, while
    the data table still contains only the real batch records.
    """

    def __init__(self, data_size, batch_size, shuffle=False, masked_rate=0):
        """
        Parameters
        ----------
        data_size : int
            Number of records available.
        batch_size : int
            Number of records per batch.
        shuffle : bool, default: False
            Shuffle the id list before slicing it into batches.
        masked_rate : int or float, default: 0
            Used by the base class to derive ``masked_batch_size``.
        """
        super(FullBatchDataGenerator, self).__init__(data_size, batch_size, shuffle, masked_rate=masked_rate)
        # Ceiling division, so a trailing partial batch is counted as well.
        self.batch_nums = (data_size + batch_size - 1) // batch_size

        LOGGER.debug(f"Init Full Batch Data Generator, batch_nums: {self.batch_nums}, batch_size: {self.batch_size}, "
                     f"masked_batch_size: {self.masked_batch_size}, shuffle: {self.shuffle}")

    def generate_data(self, data_insts, data_sids):
        """Build the index table and data table of every batch.

        Parameters
        ----------
        data_insts : table
            Source data the batch ids are joined against.
        data_sids : list of (sid, value) pairs
            All sample ids. NOTE: shuffled in place when ``self.shuffle`` is
            set.

        Returns
        -------
        tuple of lists
            ``(index_tables, batch_data_tables)``, one entry per batch.
        """
        # SystemRandom keeps no state of its own, so one instance serves the
        # shuffle and every per-batch sample.
        rng = random.SystemRandom()
        if self.shuffle:
            rng.shuffle(data_sids)

        use_mask = self.batch_size != self.masked_batch_size
        index_table = []
        batch_data = []
        for bid in range(self.batch_nums):
            batch_ids = data_sids[bid * self.batch_size:(bid + 1) * self.batch_size]
            masked_ids = self._build_masked_ids(batch_ids, data_sids, rng) if use_mask else None
            batch_index_table, batch_data_table = self._generate_batch_data_with_batch_ids(
                data_insts, batch_ids, masked_ids)
            index_table.append(batch_index_table)
            batch_data.append(batch_data_table)
        return index_table, batch_data

    def _build_masked_ids(self, batch_ids, data_sids, rng):
        """Return the batch's ids padded with random ids up to ``masked_batch_size``.

        The result is a concrete list of ``(sid, None)`` pairs rather than a
        lazy iterator, so it has a meaningful truth value and can be iterated
        more than once by whoever consumes it.
        """
        masked_ids_set = {sid for sid, _ in batch_ids}
        for pid, _ in rng.sample(data_sids, self.masked_batch_size):
            # Checked before adding, with >=, so the loop also terminates if
            # the set has already reached (or exceeded) the target size.
            if len(masked_ids_set) >= self.masked_batch_size:
                break
            masked_ids_set.add(pid)
        return [(sid, None) for sid in masked_ids_set]

    def batch_mutable(self):
        """Batches change between generations when masking or shuffling is active."""
        return self.masked_batch_size > self.batch_size or self.shuffle
class RandomBatchDataGenerator(BatchDataGenerator):
    """Batch generator that draws one random batch per generation."""

    def __init__(self, data_size, batch_size, masked_rate=0):
        """
        Parameters
        ----------
        data_size : int
            Number of records available.
        batch_size : int
            Number of records in the single random batch.
        masked_rate : int or float, default: 0
            Used by the base class to derive ``masked_batch_size``.
        """
        super(RandomBatchDataGenerator, self).__init__(data_size, batch_size, shuffle=False, masked_rate=masked_rate)
        # Exactly one batch is produced per call to generate_data.
        self.batch_nums = 1

        LOGGER.debug(f"Init Random Batch Data Generator, batch_nums: {self.batch_nums}, batch_size: {self.batch_size}, "
                     f"masked_batch_size: {self.masked_batch_size}")

    def generate_data(self, data_insts, data_sids):
        """Draw one random batch and build its index table and data table.

        Parameters
        ----------
        data_insts : table
            Source data the batch ids are joined against.
        data_sids : sequence of id pairs
            All sample ids to draw from.

        Returns
        -------
        tuple of lists
            ``([index_table], [batch_data_table])``. With masking active the
            index table is built from ``masked_batch_size`` sampled ids, while
            the data table holds only the first ``batch_size`` of them.
        """
        rng = random.SystemRandom()

        if self.masked_batch_size == self.batch_size:
            batch_ids = rng.sample(data_sids, self.batch_size)
            # The helper already returns the joined data table; joining a
            # second time here would only repeat the same work.
            batch_index_table, batch_data_table = self._generate_batch_data_with_batch_ids(data_insts, batch_ids)
            return [batch_index_table], [batch_data_table]

        masked_ids = rng.sample(data_sids, self.masked_batch_size)
        # The real batch is a prefix of the masked sample.
        batch_ids = masked_ids[: self.batch_size]
        masked_index_table, batch_data_table = self._generate_batch_data_with_batch_ids(data_insts,
                                                                                        batch_ids,
                                                                                        masked_ids)
        return [masked_index_table], [batch_data_table]
|