# data.py
  1. import numpy as np
  2. from torch.utils.data import Dataset as torchDataset
  3. from federatedml.util import LOGGER
  4. from federatedml.nn.dataset.base import Dataset, get_dataset_class
  5. from federatedml.nn.dataset.image import ImageDataset
  6. from federatedml.nn.dataset.table import TableDataset
  7. def try_dataset_class(dataset_class, path, param):
  8. # try default dataset
  9. try:
  10. dataset_inst: Dataset = dataset_class(**param)
  11. dataset_inst.load(path)
  12. return dataset_inst
  13. except Exception as e:
  14. LOGGER.warning('try to load dataset failed, exception :{}'.format(e))
  15. return None
  16. def load_dataset(dataset_name, data_path_or_dtable, param, dataset_cache: dict):
  17. # load dataset class
  18. if isinstance(data_path_or_dtable, str):
  19. cached_id = data_path_or_dtable
  20. else:
  21. cached_id = str(id(data_path_or_dtable))
  22. if cached_id in dataset_cache:
  23. LOGGER.debug('use cached dataset, cached id {}'.format(cached_id))
  24. return dataset_cache[cached_id]
  25. if dataset_name is None or dataset_name == '':
  26. # automatically match default dataset
  27. LOGGER.info('dataset is not specified, use auto inference')
  28. for ds_class in [TableDataset, ImageDataset]:
  29. dataset_inst = try_dataset_class(
  30. ds_class, data_path_or_dtable, param=param)
  31. if dataset_inst is not None:
  32. break
  33. if dataset_inst is None:
  34. raise ValueError(
  35. 'cannot find default dataset that can successfully load data from path {}, '
  36. 'please check the warning message for error details'. format(data_path_or_dtable))
  37. else:
  38. # load specified dataset
  39. dataset_class = get_dataset_class(dataset_name)
  40. dataset_inst = dataset_class(**param)
  41. dataset_inst.load(data_path_or_dtable)
  42. dataset_cache[cached_id] = dataset_inst
  43. return dataset_inst
  44. def get_ret_predict_table(id_table, pred_table, classes, partitions, computing_session):
  45. id_dtable = computing_session.parallelize(
  46. id_table, partition=partitions, include_key=True)
  47. pred_dtable = computing_session.parallelize(
  48. pred_table, partition=partitions, include_key=True)
  49. return id_dtable, pred_dtable
  50. def add_match_id(id_table: list, dataset_inst: TableDataset):
  51. assert isinstance(dataset_inst, TableDataset), 'when using match id your dataset must be a Table Dataset'
  52. for id_inst in id_table:
  53. id_inst[1].inst_id = dataset_inst.match_ids[id_inst[0]]