1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- import numpy as np
- import tensorflow as tf
- from federatedml.util import LOGGER
class FTLDataLoader(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves FTL samples from dense NumPy arrays.

    The constructor copies every sample into ``self.x`` (features) and
    ``self.y`` (labels).  Overlap samples occupy the first rows; on the guest
    side the non-overlap samples follow them.  On the host side only overlap
    samples are loaded and ``self.y`` stays all zeros.

    Both sample tables must provide ``count()``, ``first()`` and ``collect()``
    yielding ``(key, instance)`` pairs, where each instance has a ``features``
    array and (guest side) a ``label``.
    """

    def __init__(self, non_overlap_samples, overlap_samples, batch_size, guest_side=True):
        """Load the samples into memory and build the row-index bookkeeping.

        Args:
            non_overlap_samples: table of samples only this party holds; it is
                read only when ``guest_side`` is true.
            overlap_samples: table of samples shared with the other party;
                must contain at least one entry (its first entry fixes
                ``x_shape``).
            batch_size: number of rows per batch returned by ``__getitem__``.
            guest_side: when true, labels and non-overlap samples are loaded.
        """
        super().__init__()
        self.batch_size = batch_size
        self.guest_side = guest_side

        # Count each table once and reuse the result.
        # NOTE(review): count() is presumably a non-trivial table operation.
        overlap_count = overlap_samples.count()
        if guest_side:
            self.size = non_overlap_samples.count() + overlap_count
        else:
            self.size = overlap_count

        _, one_data = overlap_samples.first()
        self.y_shape = (1,)
        self.x_shape = one_data.features.shape

        self.x = np.zeros((self.size, *self.x_shape))
        self.y = np.zeros((self.size, *self.y_shape))

        self._overlap_keys = []
        self._non_overlap_keys = []

        next_row = self._fill_rows(overlap_samples, self._overlap_keys, 0)
        if guest_side:
            self._fill_rows(non_overlap_samples, self._non_overlap_keys, next_row)

        # Both index sets are ndarrays on either side so that the boolean-mask
        # arithmetic in get_relative_overlap_index works for host loaders too.
        # On the host side size == overlap_count, so the second one is empty.
        self._overlap_index = np.arange(0, overlap_count)
        self._non_overlap_index = np.arange(overlap_count, self.size)

    def _fill_rows(self, samples, keys, row):
        """Copy ``samples`` into ``x``/``y`` from ``row`` on; return the next free row.

        Each sample key is appended to ``keys`` in the order it is stored, so
        ``keys[i]`` identifies the i-th row written by this call.
        """
        for key, inst in samples.collect():
            keys.append(key)
            self.x[row] = inst.features
            if self.guest_side:
                self.y[row] = inst.label
            row += 1
        return row

    def get_overlap_indexes(self):
        """Return the row indexes of the overlap samples inside ``x``/``y``."""
        return self._overlap_index

    def get_non_overlap_indexes(self):
        """Return the row indexes of the non-overlap samples (empty on the host side)."""
        return self._non_overlap_index

    def get_batch_indexes(self, batch_index):
        """Return the half-open ``(start, end)`` row range of batch ``batch_index``.

        ``end`` is not clipped to ``size``; slicing handles the last batch.
        """
        start = self.batch_size * batch_index
        end = self.batch_size * (batch_index + 1)
        return start, end

    def get_relative_overlap_index(self, batch_index):
        """Return positions, relative to the batch start, of overlap rows in the batch."""
        start, end = self.get_batch_indexes(batch_index)
        in_batch = (self._overlap_index >= start) & (self._overlap_index < end)
        # start is a multiple of batch_size, so the modulo equals (row - start).
        return self._overlap_index[in_batch] % self.batch_size

    def get_overlap_x(self):
        """Return the feature rows of all overlap samples."""
        return self.x[self._overlap_index]

    def get_overlap_y(self):
        """Return the label rows of all overlap samples."""
        return self.y[self._overlap_index]

    def get_overlap_keys(self):
        """Return the keys of the overlap samples, in row order."""
        return self._overlap_keys

    def get_non_overlap_keys(self):
        """Return the keys of the non-overlap samples, in row order."""
        return self._non_overlap_keys

    def __getitem__(self, index):
        """Return batch ``index``: ``(x, y)`` on the guest side, ``x`` only on the host side."""
        start, end = self.get_batch_indexes(index)
        if self.guest_side:
            return self.x[start: end], self.y[start: end]
        return self.x[start: end]

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        return int(np.ceil(self.size / float(self.batch_size)))

    def get_idx(self):
        """Return every sample key in the row order of ``x``.

        Overlap keys come first, followed by non-overlap keys.  A new list is
        returned, so callers may modify it freely.
        """
        return self._overlap_keys + self._non_overlap_keys

    def data_basic_info(self):
        """Return a one-line summary of sample counts and the feature shape."""
        return 'total sample num is {}, overlap sample num is {}, non_overlap sample is {},' \
               'x_shape is {}'.format(self.size, len(self._overlap_index), len(self._non_overlap_index),
                                      self.x_shape)
|