ftl_dataloder.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import numpy as np
  2. import tensorflow as tf
  3. from federatedml.util import LOGGER
  4. class FTLDataLoader(tf.keras.utils.Sequence):
  5. def __init__(self, non_overlap_samples, overlap_samples, batch_size, guest_side=True):
  6. self.batch_size = batch_size
  7. self.guest_side = guest_side
  8. self._overlap_index = []
  9. self._non_overlap_index = []
  10. if guest_side:
  11. self.size = non_overlap_samples.count() + overlap_samples.count()
  12. else:
  13. self.size = overlap_samples.count()
  14. _, one_data = overlap_samples.first()
  15. self.y_shape = (1,)
  16. self.x_shape = one_data.features.shape
  17. self.x = np.zeros((self.size, *self.x_shape))
  18. self.y = np.zeros((self.size, *self.y_shape))
  19. index = 0
  20. self._overlap_keys = []
  21. self._non_overlap_keys = []
  22. for k, inst in overlap_samples.collect():
  23. self._overlap_keys.append(k)
  24. self.x[index] = inst.features
  25. if guest_side:
  26. self.y[index] = inst.label
  27. index += 1
  28. if self.guest_side:
  29. for k, inst in non_overlap_samples.collect():
  30. self._non_overlap_keys.append(k)
  31. self.x[index] = inst.features
  32. if guest_side:
  33. self.y[index] = inst.label
  34. index += 1
  35. if guest_side:
  36. self._overlap_index = np.array(list(range(0, overlap_samples.count())))
  37. self._non_overlap_index = np.array(list(range(overlap_samples.count(), self.size)))
  38. else:
  39. self._overlap_index = list(range(len(self.x)))
  40. def get_overlap_indexes(self):
  41. return self._overlap_index
  42. def get_non_overlap_indexes(self):
  43. return self._non_overlap_index
  44. def get_batch_indexes(self, batch_index):
  45. start = self.batch_size * batch_index
  46. end = self.batch_size * (batch_index + 1)
  47. return start, end
  48. def get_relative_overlap_index(self, batch_index):
  49. start, end = self.get_batch_indexes(batch_index)
  50. return self._overlap_index[(self._overlap_index >= start) & (self._overlap_index < end)] % self.batch_size
  51. def get_overlap_x(self):
  52. return self.x[self._overlap_index]
  53. def get_overlap_y(self):
  54. return self.y[self._overlap_index]
  55. def get_overlap_keys(self):
  56. return self._overlap_keys
  57. def get_non_overlap_keys(self):
  58. return self._non_overlap_keys
  59. def __getitem__(self, index):
  60. start, end = self.get_batch_indexes(index)
  61. if self.guest_side:
  62. return self.x[start: end], self.y[start: end]
  63. else:
  64. return self.x[start: end]
  65. def __len__(self):
  66. return int(np.ceil(self.size / float(self.batch_size)))
  67. def get_idx(self):
  68. return self._keys
  69. def data_basic_info(self):
  70. return 'total sample num is {}, overlap sample num is {}, non_overlap sample is {},'\
  71. 'x_shape is {}'.format(self.size, len(self._overlap_index), len(self._non_overlap_index),
  72. self.x_shape)