table.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import numpy as np
  2. import pandas as pd
  3. from federatedml.statistic.data_overview import with_weight
  4. from federatedml.nn.dataset.base import Dataset
  5. from federatedml.util import LOGGER
  6. class TableDataset(Dataset):
  7. """
  8. A Table Dataset, load data from a give csv path, or transform FATE DTable
  9. Parameters
  10. ----------
  11. label_col str, name of label column in csv, if None, will automatically take 'y' or 'label' or 'target' as label
  12. feature_dtype dtype of feature, supports int, long, float, double
  13. label_dtype: dtype of label, supports int, long, float, double
  14. label_shape: list or tuple, the shape of label
  15. flatten_label: bool, flatten extracted label column or not, default is False
  16. """
  17. def __init__(
  18. self,
  19. label_col=None,
  20. feature_dtype='float',
  21. label_dtype='float',
  22. label_shape=None,
  23. flatten_label=False):
  24. super(TableDataset, self).__init__()
  25. self.with_label = True
  26. self.with_sample_weight = False
  27. self.features: np.ndarray = None
  28. self.label: np.ndarray = None
  29. self.sample_weights: np.ndarray = None
  30. self.origin_table: pd.DataFrame = pd.DataFrame()
  31. self.label_col = label_col
  32. self.f_dtype = self.check_dtype(feature_dtype)
  33. self.l_dtype = self.check_dtype(label_dtype)
  34. if label_shape is not None:
  35. assert isinstance(label_shape, tuple) or isinstance(
  36. label_shape, list), 'label shape is {}'.format(label_shape)
  37. self.label_shape = label_shape
  38. self.flatten_label = flatten_label
  39. # ids, match ids is for FATE match id system
  40. self.sample_ids = None
  41. self.match_ids = None
  42. if self.label_col is not None:
  43. assert isinstance(self.label_col, str) or isinstance(
  44. self.label_col, int), 'label columns parameter must be a str or an int'
  45. @staticmethod
  46. def check_dtype(dtype):
  47. if dtype is not None:
  48. avail = ['long', 'int', 'float', 'double']
  49. assert dtype in avail, 'available dtype is {}, but got {}'.format(
  50. avail, dtype)
  51. if dtype == 'long':
  52. return np.int64
  53. if dtype == 'int':
  54. return np.int32
  55. if dtype == 'float':
  56. return np.float32
  57. if dtype == 'double':
  58. return np.float64
  59. return dtype
  60. def __getitem__(self, item):
  61. if self.with_label:
  62. if self.with_sample_weight and self.training:
  63. return self.features[item], (self.label[item], self.sample_weights[item])
  64. else:
  65. return self.features[item], self.label[item]
  66. else:
  67. return self.features[item]
  68. def __len__(self):
  69. return len(self.origin_table)
  70. def load(self, file_path):
  71. if isinstance(file_path, str):
  72. self.origin_table = pd.read_csv(file_path)
  73. elif isinstance(file_path, pd.DataFrame):
  74. self.origin_table = file_path
  75. else:
  76. # if is FATE DTable, collect data and transform to array format
  77. data_inst = file_path
  78. self.with_sample_weight = with_weight(data_inst)
  79. LOGGER.info('collecting FATE DTable, with sample weight is {}'.format(self.with_sample_weight))
  80. header = data_inst.schema["header"]
  81. LOGGER.debug('input dtable header is {}'.format(header))
  82. data = list(data_inst.collect())
  83. data_keys = [key for (key, val) in data]
  84. data_keys_map = dict(zip(sorted(data_keys), range(len(data_keys))))
  85. keys = [None for idx in range(len(data_keys))]
  86. x_ = [None for idx in range(len(data_keys))]
  87. y_ = [None for idx in range(len(data_keys))]
  88. match_ids = {}
  89. sample_weights = [1 for idx in range(len(data_keys))]
  90. for (key, inst) in data:
  91. idx = data_keys_map[key]
  92. keys[idx] = key
  93. x_[idx] = inst.features
  94. y_[idx] = inst.label
  95. match_ids[key] = inst.inst_id
  96. if self.with_sample_weight:
  97. sample_weights[idx] = inst.weight
  98. x_ = np.asarray(x_)
  99. y_ = np.asarray(y_)
  100. df = pd.DataFrame(x_)
  101. df.columns = header
  102. df['id'] = sorted(data_keys)
  103. df['label'] = y_
  104. # host data has no label, so this columns will all be None
  105. if df['label'].isna().all():
  106. df = df.drop(columns=['label'])
  107. self.origin_table = df
  108. self.sample_weights = np.array(sample_weights)
  109. self.match_ids = match_ids
  110. label_col_candidates = ['y', 'label', 'target']
  111. # automatically set id columns
  112. id_col_candidates = ['id', 'sid']
  113. for id_col in id_col_candidates:
  114. if id_col in self.origin_table:
  115. self.sample_ids = self.origin_table[id_col].values.tolist()
  116. self.origin_table = self.origin_table.drop(columns=[id_col])
  117. break
  118. # infer column name
  119. label = self.label_col
  120. if label is None:
  121. for i in label_col_candidates:
  122. if i in self.origin_table:
  123. label = i
  124. break
  125. if label is None:
  126. self.with_label = False
  127. LOGGER.warning(
  128. 'label default setting is "auto", but found no "y"/"label"/"target" in input table')
  129. else:
  130. if label not in self.origin_table:
  131. raise ValueError(
  132. 'label column {} not found in input table'.format(label))
  133. if self.with_label:
  134. self.label = self.origin_table[label].values
  135. self.features = self.origin_table.drop(columns=[label]).values
  136. if self.l_dtype:
  137. self.label = self.label.astype(self.l_dtype)
  138. if self.label_shape:
  139. self.label = self.label.reshape(self.label_shape)
  140. else:
  141. self.label = self.label.reshape((len(self.features), -1))
  142. if self.flatten_label:
  143. self.label = self.label.flatten()
  144. else:
  145. self.label = None
  146. self.features = self.origin_table.values
  147. if self.f_dtype:
  148. self.features = self.features.astype(self.f_dtype)
  149. def get_classes(self):
  150. if self.label is not None:
  151. return np.unique(self.label).tolist()
  152. else:
  153. raise ValueError(
  154. 'no label found, please check if self.label is set')
  155. def get_sample_ids(self):
  156. return self.sample_ids
  157. def get_match_ids(self):
  158. return self.match_ids