data_split.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import collections

from sklearn.model_selection import train_test_split

from fate_arch.session import computing_session
from federatedml.feature.binning.base_binning import BaseBinning
from federatedml.model_base import Metric, MetricMeta, ModelBase
from federatedml.param.data_split_param import DataSplitParam
from federatedml.util import LOGGER
from federatedml.util import data_transform
from federatedml.util.consts import FLOAT_ZERO

ROUND_NUM = 3


class DataSplitter(ModelBase):
    def __init__(self):
        super().__init__()
        self.metric_name = "data_split"
        self.metric_namespace = "train"
        self.metric_type = "DATA_SPLIT"
        self.model_param = DataSplitParam()
        self.role = None
        self.need_transform = None

    def _init_model(self, params):
        self.random_state = params.random_state
        self.test_size = params.test_size
        self.train_size = params.train_size
        self.validate_size = params.validate_size
        self.stratified = params.stratified
        self.shuffle = params.shuffle
        self.split_points = params.split_points
        if self.split_points:
            self.split_points = sorted(self.split_points)
        self.need_run = params.need_run

    @staticmethod
    def _safe_divide(n, d):
        result = n / d if d > FLOAT_ZERO else 0.0
        if result >= 1:
            result = 1.0
        return result
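
    # Example (illustrative, not from the original source): _safe_divide(1, 4)
    # returns 0.25; _safe_divide(5, 0) returns 0.0 rather than raising
    # ZeroDivisionError; _safe_divide(5, 4) is clamped to 1.0, keeping derived
    # set ratios inside [0, 1].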

    def _split(self, ids, y, test_size, train_size):
        if test_size <= FLOAT_ZERO:
            return ids, [], y, []
        if train_size <= FLOAT_ZERO:
            return [], ids, [], y
        stratify = y if self.stratified else None
        if not isinstance(test_size, int):
            train_size = round(train_size * len(ids))
            test_size = len(ids) - train_size
        id_train, id_test, y_train, y_test = train_test_split(ids, y,
                                                              test_size=test_size, train_size=train_size,
                                                              random_state=self.random_state,
                                                              shuffle=self.shuffle, stratify=stratify)
        return id_train, id_test, y_train, y_test
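
    # Usage sketch (hedged): with 10 ids, test_size=0.3 and train_size=0.7,
    # the float fractions are first converted to exact counts
    # (train = round(0.7 * 10) = 7, test = 10 - 7 = 3) before sklearn's
    # train_test_split runs, so no id is lost to rounding; a zero-sized side
    # short-circuits without calling sklearn at all.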

    def _get_ids(self, data_inst):
        ids = sorted([i for i, v in data_inst.mapValues(lambda v: None).collect()])
        return ids

    def _get_y(self, data_inst):
        if self.stratified:
            y = [v for i, v in data_inst.mapValues(lambda v: v.label).collect()]
            if self.need_transform:
                y = self.transform_regression_label(data_inst)
        else:
            # make dummy y
            y = [0] * (data_inst.count())
        return y

    def check_need_transform(self):
        if self.split_points is not None:
            if len(self.split_points) == 0:
                self.need_transform = False
            else:
                # only need to produce binned labels if stratified split needed
                if self.stratified:
                    self.need_transform = True
        return

    @staticmethod
    def get_train_test_size(train_size, test_size):
        LOGGER.debug(f"original train_size is {train_size}, original test_size is {test_size}")
        # return original set sizes unchanged if given as int counts
        if isinstance(test_size, int) and isinstance(train_size, int):
            return train_size, test_size
        total_size = test_size + train_size
        new_train_size = DataSplitter._safe_divide(train_size, total_size)
        new_test_size = DataSplitter._safe_divide(test_size, total_size)
        LOGGER.debug(f"new_train_size is {new_train_size}, new_test_size is {new_test_size}")
        return new_train_size, new_test_size
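
    # Worked example (illustrative): get_train_test_size(0.6, 0.2) normalizes
    # the fractions against their sum: train = 0.6 / 0.8 = 0.75 and
    # test = 0.2 / 0.8 = 0.25, the proportions needed for the second-stage
    # split of the remaining data; integer counts pass through unchanged.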

    def param_validator(self, data_inst):
        """
        Validate & transform param inputs
        """
        # check if need label transform
        self.check_need_transform()

        # check & transform data set sizes
        n_count = data_inst.count()
        if isinstance(self.test_size, float) or isinstance(self.train_size, float) \
                or isinstance(self.validate_size, float):
            total_size = 1.0
        else:
            total_size = n_count
        if self.train_size is None:
            if self.validate_size is None:
                self.train_size = total_size - self.test_size
                self.validate_size = total_size - (self.test_size + self.train_size)
            else:
                if self.test_size is None:
                    self.test_size = 0
                self.train_size = total_size - (self.validate_size + self.test_size)
        elif self.test_size is None:
            if self.validate_size is None:
                self.test_size = total_size - self.train_size
                self.validate_size = total_size - (self.test_size + self.train_size)
            else:
                self.test_size = total_size - (self.validate_size + self.train_size)
        elif self.validate_size is None:
            if self.train_size is None:
                self.train_size = total_size - self.test_size
            self.validate_size = total_size - (self.test_size + self.train_size)

        if abs((abs(self.train_size) + abs(self.test_size) + abs(self.validate_size)) - total_size) > FLOAT_ZERO:
            raise ValueError("train_size, test_size, validate_size should sum up to 1.0 or data count")
        return
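
    # Example (illustrative): with train_size=0.8 and both test_size and
    # validate_size left as None, total_size is 1.0 (a float size is present),
    # so the validator infers test_size = 1.0 - 0.8 = 0.2 and
    # validate_size = 1.0 - (0.2 + 0.8) = 0.0 (up to float rounding); the
    # final check then asserts the three sizes sum to 1.0 (fractions) or to
    # the row count (absolute counts).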

    def transform_regression_label(self, data_inst):
        edge = self.split_points[-1] + 1
        split_points_bin = self.split_points + [edge]
        bin_labels = data_inst.mapValues(lambda v: BaseBinning.get_bin_num(v.label, split_points_bin))
        binned_y = [v for k, v in bin_labels.collect()]
        return binned_y
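
    # Example (hedged reading of BaseBinning.get_bin_num): with
    # split_points=[0.5, 1.5], labels are binned against [0.5, 1.5, 2.5]
    # (the extra edge guarantees every value falls in some bin), so a
    # continuous label such as 0.3 maps to bin 0, 1.0 to bin 1 and 2.0 to
    # bin 2; these bin indices then act as class labels for stratification.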

    @staticmethod
    def get_class_freq(y, split_points=None, label_names=None):
        """
        Get frequency info of a given y set; only called when stratified is true
        :param y: list, y sample
        :param split_points: list, split points used to bin regression values
        :param label_names: list, label names of all data
        :return: dict
        """
        freq_dict = collections.Counter(y)
        freq_keys = freq_dict.keys()
        # continuous label
        if split_points is not None and len(split_points) > 0:
            label_count = len(split_points) + 1
            # fill in count for missing bins
            if len(freq_keys) < label_count:
                for i in range(label_count):
                    if i not in freq_keys:
                        freq_dict[i] = 0
        # categorical label
        else:
            if label_names is None:
                raise ValueError("No label values collected.")
            label_count = len(label_names)
            # fill in count for missing labels
            if len(freq_keys) < label_count:
                for label in label_names:
                    if label not in freq_keys:
                        freq_dict[label] = 0
        return freq_dict
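
    # Example (illustrative): get_class_freq([0, 0, 1], label_names=[0, 1, 2])
    # returns Counter({0: 2, 1: 1, 2: 0}); the absent class 2 is back-filled
    # with a zero count so every expected label shows up in the metric output.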

    def callback_count_info(self, id_train, id_validate, id_test, all_metas):
        """
        Tool to callback returned data count & ratio information

        Parameters
        ----------
        id_train: list or table, id of train data set
        id_validate: list or table, id of validate data set
        id_test: list or table, id of test data set
        all_metas: dict, all meta info

        Returns
        -------
        dict
        """
        metas = {}
        if isinstance(id_train, list):
            train_count = len(id_train)
            validate_count = len(id_validate)
            test_count = len(id_test)
        else:
            train_count = id_train.count()
            validate_count = id_validate.count()
            test_count = id_test.count()
        metas["train"] = train_count
        metas["validate"] = validate_count
        metas["test"] = test_count
        original_count = train_count + validate_count + test_count
        metas["original"] = original_count
        metric_name = f"{self.metric_name}_count_info"
        all_metas[metric_name] = metas

        metas = {}
        train_ratio = train_count / original_count
        validate_ratio = validate_count / original_count
        test_ratio = test_count / original_count
        metas["train"] = round(train_ratio, ROUND_NUM)
        metas["validate"] = round(validate_ratio, ROUND_NUM)
        metas["test"] = round(test_ratio, ROUND_NUM)
        metric_name = f"{self.metric_name}_ratio_info"
        all_metas[metric_name] = metas
        # stratified
        all_metas["stratified"] = self.stratified
        return all_metas
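
    # Example (illustrative): an 80/10/10 split of 1000 rows adds
    # {"data_split_count_info": {"train": 800, "validate": 100, "test": 100,
    #  "original": 1000},
    #  "data_split_ratio_info": {"train": 0.8, "validate": 0.1, "test": 0.1},
    #  "stratified": False}
    # to all_metas (ratios rounded to ROUND_NUM = 3 decimal places).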

    def callback_label_info(self, y_train, y_validate, y_test, all_metas):
        """
        Tool to callback returned data label information

        Parameters
        ----------
        y_train: list, train labels
        y_validate: list, validate labels
        y_test: list, test labels
        all_metas: dict, all meta info

        Returns
        -------
        dict
        """
        metas = {}
        y_all = y_train + y_validate + y_test
        label_names = None
        if self.split_points is None:
            label_names = list(set(y_all))
        original_freq_dict = DataSplitter.get_class_freq(y_all, self.split_points, label_names)
        metas["original"] = original_freq_dict
        train_freq_dict = DataSplitter.get_class_freq(y_train, self.split_points, label_names)
        metas["train"] = train_freq_dict
        validate_freq_dict = DataSplitter.get_class_freq(y_validate, self.split_points, label_names)
        metas["validate"] = validate_freq_dict
        test_freq_dict = DataSplitter.get_class_freq(y_test, self.split_points, label_names)
        metas["test"] = test_freq_dict
        if self.split_points is not None and len(self.split_points) > 0:
            metas["split_points"] = self.split_points
            metas["continuous_label"] = True
        else:
            metas["label_names"] = label_names
            metas["continuous_label"] = False
        metric_name = f"{self.metric_name}_label_info"
        all_metas[metric_name] = metas
        return all_metas

    def callback(self, metas):
        metric = [Metric(self.metric_name, 0)]
        self.callback_metric(metric_name=self.metric_name,
                             metric_namespace=self.metric_namespace,
                             metric_data=metric)
        self.tracker.set_metric_meta(metric_name=self.metric_name,
                                     metric_namespace=self.metric_namespace,
                                     metric_meta=MetricMeta(name=self.metric_name,
                                                            metric_type=self.metric_type,
                                                            extra_metas=metas))

    @staticmethod
    def _match_id(data_inst, id_table):
        # keep only the rows of data_inst whose key appears in id_table
        return data_inst.join(id_table, lambda v1, v2: v1)

    @staticmethod
    def _parallelize_ids(ids, partitions):
        ids = [(i, None) for i in ids]
        id_table = computing_session.parallelize(ids, include_key=True, partition=partitions)
        return id_table
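
    # Usage sketch (hedged): after _split returns plain id lists, they can be
    # turned back into tables and matched against the input data, e.g.
    #     id_table = DataSplitter._parallelize_ids(id_train, data_inst.partitions)
    #     train_data = DataSplitter._match_id(data_inst, id_table)
    # the join keeps original instance values for exactly the selected keys.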

    @staticmethod
    def _set_output_table_schema(data_inst, schema):
        if schema is not None and data_inst.count() > 0:
            data_transform.set_schema(data_inst, schema)

    def split_data(self, data_inst, id_train, id_validate, id_test):
        train_data = DataSplitter._match_id(data_inst, id_train)
        validate_data = DataSplitter._match_id(data_inst, id_validate)
        test_data = DataSplitter._match_id(data_inst, id_test)
        schema = getattr(data_inst, "schema", None)
        self._set_output_table_schema(train_data, schema)
        self._set_output_table_schema(validate_data, schema)
        self._set_output_table_schema(test_data, schema)
        return train_data, validate_data, test_data
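
    # Note (hedged): split_data expects id_train/id_validate/id_test to be id
    # tables (see _parallelize_ids above); each non-empty output table is
    # re-stamped with the input schema so downstream components still see the
    # original header and label metadata.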

    def fit(self, data_inst):
        raise NotImplementedError("fit method in data_split should not be called here.")