base.py

from torch.utils.data import Dataset as Dataset_
from federatedml.nn.backend.utils.common import ML_PATH
import importlib
import abc
import numpy as np


class Dataset(Dataset_):
    """Base class for user-defined datasets.

    Subclasses must implement load(), __getitem__() and __len__();
    get_classes() and get_sample_ids() are optional.
    """

    def __init__(self, **kwargs):
        super(Dataset, self).__init__()
        self._type = 'local'  # train/predict
        self._check = False
        self._generated_ids = None
        self.training = True

    @property
    def dataset_type(self):
        if not hasattr(self, '_type'):
            raise AttributeError(
                'type variable does not exist, call __init__ of the super class')
        return self._type

    @dataset_type.setter
    def dataset_type(self, val):
        self._type = val

    def has_dataset_type(self):
        return self.dataset_type

    def set_type(self, _type):
        self.dataset_type = _type

    def get_type(self):
        return self.dataset_type

    def has_sample_ids(self):
        # if get_sample_ids() is not implemented, return False
        try:
            sample_ids = self.get_sample_ids()
        except NotImplementedError:
            return False
        except BaseException as e:
            raise e

        if sample_ids is None:
            return False
        else:
            if not self._check:
                assert isinstance(
                    sample_ids, list), 'get_sample_ids() must return a list of str or int'
                for id_ in sample_ids:
                    if (not isinstance(id_, str)) and (not isinstance(id_, int)):
                        raise RuntimeError(
                            'get_sample_ids() must return a list of str or int: got id {} of type {}'.format(
                                id_, type(id_)))
                assert len(sample_ids) == len(
                    self), 'sample id len:{} != dataset length:{}'.format(len(sample_ids), len(self))
                self._check = True
            return True

    def init_sid_and_getfunc(self, prefix: str = None):
        # generate sample ids '<prefix>_0', '<prefix>_1', ... and override
        # get_sample_ids() so that it returns them
        if prefix is not None:
            assert isinstance(
                prefix, str), 'prefix must be a str, but got {}'.format(prefix)
        else:
            prefix = self._type
        generated_ids = []
        for i in range(0, self.__len__()):
            generated_ids.append(prefix + '_' + str(i))
        self._generated_ids = generated_ids

        def get_func():
            return self._generated_ids

        self.get_sample_ids = get_func

    """
    Functions for users
    """

    def train(self, ):
        self.training = True

    def eval(self, ):
        self.training = False

    # Functions to be implemented

    @abc.abstractmethod
    def load(self, file_path):
        raise NotImplementedError(
            'You must implement load function so that Client can pass file-path to this '
            'class')

    def __getitem__(self, item):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def get_classes(self):
        raise NotImplementedError()

    def get_sample_ids(self):
        raise NotImplementedError()


class ShuffleWrapDataset(Dataset_):
    """Wraps a Dataset and exposes it in a deterministic shuffled order.

    Samples are first sorted by sample id, then shuffled with a fixed
    seed, so parties sharing the same seed see the same order.
    """

    def __init__(self, dataset: Dataset, shuffle_seed=100):
        super(ShuffleWrapDataset, self).__init__()
        assert isinstance(dataset, Dataset)
        self.ds = dataset
        ids = self.ds.get_sample_ids()
        sort_idx = np.argsort(np.array(ids))
        self.idx = sort_idx
        if shuffle_seed is not None:
            np.random.seed(shuffle_seed)
            self.shuffled_idx = np.copy(self.idx)
            np.random.shuffle(self.shuffled_idx)
        else:
            self.shuffled_idx = np.copy(self.idx)
        self.idx_map = {k: v for k, v in zip(self.idx, self.shuffled_idx)}

    def train(self, ):
        self.ds.train()

    def eval(self, ):
        self.ds.eval()

    def __getitem__(self, item):
        return self.ds[self.idx_map[self.idx[item]]]

    def __len__(self):
        return len(self.ds)

    def __repr__(self):
        return self.ds.__repr__()

    def has_sample_ids(self):
        return self.ds.has_sample_ids()

    def set_shuffled_idx(self, idx_map: dict):
        self.shuffled_idx = np.array(list(idx_map.values()))
        self.idx_map = idx_map

    def get_sample_ids(self):
        ids = self.ds.get_sample_ids()
        return np.array(ids)[self.shuffled_idx].tolist()

    def get_classes(self):
        return self.ds.get_classes()


def get_dataset_class(dataset_module_name: str):
    """Import {ML_PATH}.dataset.<module> and return its Dataset subclass."""

    if dataset_module_name.endswith('.py'):
        dataset_module_name = dataset_module_name.replace('.py', '')
    ds_modules = importlib.import_module(
        '{}.dataset.{}'.format(ML_PATH, dataset_module_name))
    for k, v in ds_modules.__dict__.items():
        if isinstance(v, type):
            if issubclass(v, Dataset) and v is not Dataset:
                return v
    raise ValueError(
        'Did not find any class in {}.py that is a subclass of the Dataset class'.format(
            dataset_module_name))
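

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original base.py): how the
# pieces above are typically combined. The module name 'table' and the
# file path are assumptions made for illustration only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    ds_class = get_dataset_class('table')   # resolves {ML_PATH}.dataset.table
    ds = ds_class()
    ds.load('/path/to/local_data.csv')
    if not ds.has_sample_ids():
        # fall back to generated ids: 'local_0', 'local_1', ...
        ds.init_sid_and_getfunc(prefix=ds.get_type())
    shuffled_ds = ShuffleWrapDataset(ds, shuffle_seed=100)
    print(len(shuffled_ds), shuffled_ds.get_sample_ids()[:5])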