123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- from torch.utils.data import Dataset as Dataset_
- from federatedml.nn.backend.utils.common import ML_PATH
- import importlib
- import abc
- import numpy as np
class Dataset(Dataset_):
    """Base dataset class that adds a dataset type, a train/eval flag and
    sample-id helpers on top of ``torch.utils.data.Dataset``.

    Subclasses are expected to override ``load``, ``__getitem__`` and
    ``__len__``; ``get_classes`` and ``get_sample_ids`` are optional hooks.
    Every hook raises ``NotImplementedError`` in this base class.
    """

    def __init__(self, **kwargs):
        """Initialise bookkeeping attributes; ``kwargs`` are accepted but not used here."""
        super(Dataset, self).__init__()
        self._type = 'local'  # train/predict
        # True once the ids returned by get_sample_ids() have been validated.
        self._check = False
        # Filled by init_sid_and_getfunc() when ids are auto-generated.
        self._generated_ids = None
        self.training = True

    @property
    def dataset_type(self):
        """Return the dataset type string.

        Raises:
            AttributeError: if ``_type`` is missing, i.e. a subclass did not
                call this class's ``__init__``.
        """
        if not hasattr(self, '_type'):
            raise AttributeError(
                'type variable not exists, call __init__ of super class')
        return self._type

    @dataset_type.setter
    def dataset_type(self, val):
        self._type = val

    def has_dataset_type(self):
        """Return the dataset type itself (not a bool); callers rely on its truthiness."""
        return self.dataset_type

    def set_type(self, _type):
        """Set the dataset type."""
        self.dataset_type = _type

    def get_type(self):
        """Return the dataset type."""
        return self.dataset_type

    def has_sample_ids(self):
        """Tell whether this dataset provides usable sample ids.

        Returns:
            False if ``get_sample_ids`` is not implemented or returns None,
            True otherwise.

        Raises:
            AssertionError: on the first validation, if the ids are not a
                list or their count differs from ``len(self)``.
            RuntimeError: on the first validation, if an id is neither a str
                nor an int.

        Any exception other than ``NotImplementedError`` raised by
        ``get_sample_ids`` propagates unchanged.
        """
        # if not implement get_sample_ids, return False
        try:
            sample_ids = self.get_sample_ids()
        except NotImplementedError:
            return False

        if sample_ids is None:
            return False

        # Validate only once; later calls skip the scan over all ids.
        if not self._check:
            assert isinstance(
                sample_ids, list), 'get_sample_ids() must return a list contains str or integer'
            for id_ in sample_ids:
                if not isinstance(id_, (str, int)):
                    raise RuntimeError(
                        'get_sample_ids() must return a list contains str or integer: got id of type {}:{}'.format(
                            id_, type(id_)))
            assert len(sample_ids) == len(
                self), 'sample id len:{} != dataset length:{}'.format(len(sample_ids), len(self))
            self._check = True
        return True

    def init_sid_and_getfunc(self, prefix: str = None):
        """Generate ids of the form ``<prefix>_<index>`` and install a
        ``get_sample_ids`` that returns them.

        Args:
            prefix: id prefix; when None the current dataset type is used.

        Raises:
            AssertionError: if ``prefix`` is given but is not a str.
        """
        if prefix is not None:
            assert isinstance(
                prefix, str), 'prefix must be a str, but got {}'.format(prefix)
        else:
            prefix = self._type

        self._generated_ids = [prefix + '_' + str(i) for i in range(len(self))]

        def get_func():
            return self._generated_ids

        # Instance attribute shadows the class-level get_sample_ids for this
        # object only.
        self.get_sample_ids = get_func

    # Functions for users

    def train(self, ):
        """Switch the dataset to training mode."""
        self.training = True

    def eval(self, ):
        """Switch the dataset to evaluation mode."""
        self.training = False

    # Function to implemented

    @abc.abstractmethod
    def load(self, file_path):
        """Load data from ``file_path``; must be overridden by subclasses."""
        raise NotImplementedError(
            'You must implement load function so that Client can pass file-path to this '
            'class')

    def __getitem__(self, item):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def get_classes(self):
        """Optional hook; not implemented in the base class."""
        raise NotImplementedError()

    def get_sample_ids(self):
        """Optional hook returning a list of str/int ids, one per sample."""
        raise NotImplementedError()
class ShuffleWrapDataset(Dataset_):
    """Wrap a ``Dataset`` and serve its samples ordered by sorted sample id,
    optionally shuffled with a fixed seed.

    Position ``i`` of the wrapper maps to position ``shuffled_idx[i]`` of the
    wrapped dataset; ``get_sample_ids`` reports the ids in that same order.
    """

    def __init__(self, dataset: Dataset, shuffle_seed=100):
        """
        Args:
            dataset: the dataset to wrap; must implement ``get_sample_ids``.
            shuffle_seed: seed for the shuffle; None keeps the id-sorted order.

        Raises:
            AssertionError: if ``dataset`` is not a ``Dataset`` instance.
        """
        super(ShuffleWrapDataset, self).__init__()
        # Validate before touching the object so that a wrong argument fails
        # with the intended assertion instead of an unrelated AttributeError.
        assert isinstance(dataset, Dataset)
        self.ds = dataset

        ids = self.ds.get_sample_ids()
        # Positions of the wrapped dataset ordered by sample id.
        self.idx = np.argsort(np.array(ids))

        self.shuffled_idx = np.copy(self.idx)
        if shuffle_seed is not None:
            # NOTE: this reseeds NumPy's global RNG, which is visible to any
            # other code using np.random afterwards.
            np.random.seed(shuffle_seed)
            np.random.shuffle(self.shuffled_idx)

        # sorted position -> position actually served.
        self.idx_map = dict(zip(self.idx, self.shuffled_idx))

    def train(self, ):
        """Put the wrapped dataset in training mode."""
        self.ds.train()

    def eval(self, ):
        """Put the wrapped dataset in evaluation mode."""
        self.ds.eval()

    def __getitem__(self, item):
        return self.ds[self.idx_map[self.idx[item]]]

    def __len__(self):
        return len(self.ds)

    def __repr__(self):
        return self.ds.__repr__()

    def has_sample_ids(self):
        """Delegate to the wrapped dataset."""
        return self.ds.has_sample_ids()

    def set_shuffled_idx(self, idx_map: dict):
        """Replace the index mapping with an externally supplied one.

        NOTE(review): ``shuffled_idx`` takes the dict's values in insertion
        order while ``__getitem__`` looks entries up by ``self.idx``; the two
        agree only if the keys are ordered like ``self.idx`` - confirm with
        callers.
        """
        self.shuffled_idx = np.array(list(idx_map.values()))
        self.idx_map = idx_map

    def get_sample_ids(self):
        """Return the wrapped dataset's ids in the order this wrapper serves them."""
        ids = self.ds.get_sample_ids()
        return np.array(ids)[self.shuffled_idx].tolist()

    def get_classes(self):
        """Delegate to the wrapped dataset."""
        return self.ds.get_classes()
def get_dataset_class(dataset_module_name: str):
    """Import ``<ML_PATH>.dataset.<dataset_module_name>`` and return the first
    ``Dataset`` subclass found in that module's namespace.

    Args:
        dataset_module_name: module name, with or without a trailing ``.py``.

    Returns:
        The first class in the module's ``__dict__`` that subclasses
        ``Dataset`` (excluding ``Dataset`` itself).

    Raises:
        ValueError: if the module contains no such class.
    """
    if dataset_module_name.endswith('.py'):
        # Strip only the trailing extension; str.replace would also remove
        # any '.py' occurring inside the name.
        dataset_module_name = dataset_module_name[:-len('.py')]

    ds_modules = importlib.import_module(
        '{}.dataset.{}'.format(
            ML_PATH, dataset_module_name))

    for v in vars(ds_modules).values():
        if isinstance(v, type) and issubclass(v, Dataset) and v is not Dataset:
            return v

    raise ValueError('Did not find any class in {}.py that is the subclass of Dataset class'.
                     format(dataset_module_name))
|