homo_data_split.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.model_selection.data_split.data_split import DataSplitter
  17. from federatedml.util import LOGGER
  18. class HomoDataSplitHost(DataSplitter):
  19. def __init__(self):
  20. super().__init__()
  21. def fit(self, data_inst):
  22. LOGGER.debug(f"Enter Hetero {self.role} Data Split fit")
  23. if self.need_run is False:
  24. return
  25. self.param_validator(data_inst)
  26. ids = self._get_ids(data_inst)
  27. y = self._get_y(data_inst)
  28. id_train, id_test_validate, y_train, y_test_validate = self._split(
  29. ids, y, test_size=self.test_size + self.validate_size, train_size=self.train_size)
  30. validate_size, test_size = DataSplitter.get_train_test_size(self.validate_size, self.test_size)
  31. id_validate, id_test, y_validate, y_test = self._split(id_test_validate, y_test_validate,
  32. test_size=test_size, train_size=validate_size)
  33. LOGGER.info(f"Split ids obtained.")
  34. partitions = data_inst.partitions
  35. id_train_table = DataSplitter._parallelize_ids(id_train, partitions)
  36. id_validate_table = DataSplitter._parallelize_ids(id_validate, partitions)
  37. id_test_table = DataSplitter._parallelize_ids(id_test, partitions)
  38. train_data, validate_data, test_data = self.split_data(data_inst,
  39. id_train_table,
  40. id_validate_table,
  41. id_test_table)
  42. LOGGER.info(f"Split data finished.")
  43. all_metas = {}
  44. all_metas = self.callback_count_info(id_train, id_validate, id_test, all_metas)
  45. if self.stratified:
  46. all_metas = self.callback_label_info(y_train, y_validate, y_test, all_metas)
  47. self.callback(all_metas)
  48. self.set_summary(all_metas)
  49. return [train_data, validate_data, test_data]
  50. class HomoDataSplitGuest(DataSplitter):
  51. def __init__(self):
  52. super().__init__()
  53. def fit(self, data_inst):
  54. LOGGER.debug(f"Enter Hetero {self.role} Data Split fit")
  55. if self.need_run is False:
  56. return
  57. self.param_validator(data_inst)
  58. ids = self._get_ids(data_inst)
  59. y = self._get_y(data_inst)
  60. id_train, id_test_validate, y_train, y_test_validate = self._split(
  61. ids, y, test_size=self.test_size + self.validate_size, train_size=self.train_size)
  62. validate_size, test_size = DataSplitter.get_train_test_size(self.validate_size, self.test_size)
  63. id_validate, id_test, y_validate, y_test = self._split(id_test_validate, y_test_validate,
  64. test_size=test_size, train_size=validate_size)
  65. LOGGER.info(f"Split ids obtained.")
  66. partitions = data_inst.partitions
  67. id_train_table = DataSplitter._parallelize_ids(id_train, partitions)
  68. id_validate_table = DataSplitter._parallelize_ids(id_validate, partitions)
  69. id_test_table = DataSplitter._parallelize_ids(id_test, partitions)
  70. train_data, validate_data, test_data = self.split_data(data_inst,
  71. id_train_table,
  72. id_validate_table,
  73. id_test_table)
  74. LOGGER.info(f"Split data finished.")
  75. all_metas = {}
  76. all_metas = self.callback_count_info(id_train, id_validate, id_test, all_metas)
  77. if self.stratified:
  78. all_metas = self.callback_label_info(y_train, y_validate, y_test, all_metas)
  79. self.callback(all_metas)
  80. self.set_summary(all_metas)
  81. LOGGER.info(f"Callback given.")
  82. return [train_data, validate_data, test_data]