hetero_data_split.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.model_selection.data_split.data_split import DataSplitter
  17. from federatedml.transfer_variable.transfer_class.data_split_transfer_variable import \
  18. DataSplitTransferVariable
  19. from federatedml.util import LOGGER
  20. from federatedml.util import consts
  21. class HeteroDataSplitHost(DataSplitter):
  22. def __init__(self):
  23. super().__init__()
  24. self.transfer_variable = DataSplitTransferVariable()
  25. def fit(self, data_inst):
  26. if self.need_run is False:
  27. return
  28. LOGGER.debug(f"Enter Hetero {self.role} Data Split fit")
  29. id_train_table = self.transfer_variable.id_train.get(idx=0)
  30. id_test_table = self.transfer_variable.id_test.get(idx=0)
  31. id_validate_table = self.transfer_variable.id_validate.get(idx=0)
  32. LOGGER.info(f"ids obtained from Guest.")
  33. train_data, validate_data, test_data = self.split_data(data_inst,
  34. id_train_table,
  35. id_validate_table,
  36. id_test_table)
  37. LOGGER.info(f"Split data finished.")
  38. all_metas = {}
  39. all_metas = self.callback_count_info(id_train_table,
  40. id_validate_table,
  41. id_test_table,
  42. all_metas)
  43. self.callback(all_metas)
  44. self.set_summary(all_metas)
  45. LOGGER.info(f"Callback given.")
  46. return [train_data, validate_data, test_data]
  47. class HeteroDataSplitGuest(DataSplitter):
  48. def __init__(self):
  49. super().__init__()
  50. self.transfer_variable = DataSplitTransferVariable()
  51. def fit(self, data_inst):
  52. LOGGER.debug(f"Enter Hetero {self.role} Data Split fit")
  53. if self.need_run is False:
  54. return
  55. self.param_validator(data_inst)
  56. ids = self._get_ids(data_inst)
  57. y = self._get_y(data_inst)
  58. id_train, id_test_validate, y_train, y_test_validate = self._split(
  59. ids, y, test_size=self.test_size + self.validate_size, train_size=self.train_size)
  60. validate_size, test_size = DataSplitter.get_train_test_size(self.validate_size, self.test_size)
  61. id_validate, id_test, y_validate, y_test = self._split(id_test_validate, y_test_validate,
  62. test_size=test_size, train_size=validate_size)
  63. LOGGER.info(f"Split ids obtained.")
  64. partitions = data_inst.partitions
  65. id_train_table = DataSplitter._parallelize_ids(id_train, partitions)
  66. id_validate_table = DataSplitter._parallelize_ids(id_validate, partitions)
  67. id_test_table = DataSplitter._parallelize_ids(id_test, partitions)
  68. self.transfer_variable.id_train.remote(obj=id_train_table, role=consts.HOST, idx=-1)
  69. self.transfer_variable.id_test.remote(obj=id_test_table, role=consts.HOST, idx=-1)
  70. self.transfer_variable.id_validate.remote(obj=id_validate_table, role=consts.HOST, idx=-1)
  71. LOGGER.info(f"ids remote to Host(s)")
  72. train_data, validate_data, test_data = self.split_data(data_inst,
  73. id_train_table,
  74. id_validate_table,
  75. id_test_table)
  76. LOGGER.info(f"Split data finished.")
  77. all_metas = {}
  78. all_metas = self.callback_count_info(id_train, id_validate, id_test, all_metas)
  79. # summary["data_split_count_info"] = all_metas
  80. if self.stratified:
  81. all_metas = self.callback_label_info(y_train, y_validate, y_test, all_metas)
  82. #summary["data_split_label_info"] = all_metas
  83. self.callback(all_metas)
  84. self.set_summary(all_metas)
  85. LOGGER.info(f"Callback given.")
  86. return [train_data, validate_data, test_data]