selection_info_sync.py 4.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. from federatedml.feature.feature_selection.selection_properties import SelectionProperties
  18. from federatedml.transfer_variable.transfer_class.hetero_feature_selection_transfer_variable import \
  19. HeteroFeatureSelectionTransferVariable
  20. from federatedml.statistic.data_overview import look_up_names_from_header
  21. from federatedml.util import LOGGER
  22. from federatedml.util import consts
  23. class Guest(object):
  24. # noinspection PyAttributeOutsideInit
  25. def register_selection_trans_vars(self, transfer_variable):
  26. self._host_select_cols_transfer = transfer_variable.host_select_cols
  27. self._result_left_cols_transfer = transfer_variable.result_left_cols
  28. def sync_select_cols(self, suffix=tuple()):
  29. host_select_col_names = self._host_select_cols_transfer.get(idx=-1, suffix=suffix)
  30. # LOGGER.debug(f"In sync_select_cols, host_names: {host_select_col_names}")
  31. host_selection_params = []
  32. for host_id, select_names in enumerate(host_select_col_names):
  33. host_selection_properties = SelectionProperties()
  34. host_selection_properties.set_header(select_names)
  35. host_selection_properties.set_last_left_col_indexes([x for x in range(len(select_names))])
  36. host_selection_properties.add_select_col_names(select_names)
  37. host_selection_params.append(host_selection_properties)
  38. return host_selection_params
  39. def sync_select_results(self, host_selection_inner_params, suffix=tuple()):
  40. for host_id, host_select_results in enumerate(host_selection_inner_params):
  41. # LOGGER.debug("Send host selected result, left_col_names: {}".format(host_select_results.left_col_names))
  42. self._result_left_cols_transfer.remote(host_select_results.left_col_names,
  43. role=consts.HOST,
  44. idx=host_id,
  45. suffix=suffix)
  46. class Host(object):
  47. # noinspection PyAttributeOutsideInit
  48. def register_selection_trans_vars(self, transfer_variable: HeteroFeatureSelectionTransferVariable):
  49. self._host_select_cols_transfer = transfer_variable.host_select_cols
  50. self._result_left_cols_transfer = transfer_variable.result_left_cols
  51. def sync_select_cols(self, encoded_names, suffix=tuple()):
  52. self._host_select_cols_transfer.remote(encoded_names,
  53. role=consts.GUEST,
  54. idx=0,
  55. suffix=suffix)
  56. def sync_select_results_old(self, selection_param, decode_func=None, suffix=tuple()):
  57. left_cols_names = self._result_left_cols_transfer.get(idx=0, suffix=suffix)
  58. for col_name in left_cols_names:
  59. if decode_func is not None:
  60. col_name = decode_func(col_name)
  61. selection_param.add_left_col_name(col_name)
  62. LOGGER.debug("Received host selected result, original left_cols: {},"
  63. " left_col_names: {}".format(left_cols_names, selection_param.left_col_names))
  64. def sync_select_results(self, selection_param, header=None, anonymous_header=None, suffix=tuple()):
  65. left_col_names = self._result_left_cols_transfer.get(idx=0, suffix=suffix)
  66. if header is not None and anonymous_header is not None:
  67. left_col_plain_names = look_up_names_from_header(left_col_names, anonymous_header, header)
  68. for col_name in left_col_plain_names:
  69. selection_param.add_left_col_name(col_name)
  70. # LOGGER.debug(f"Received host selected result, original left_cols: {left_col_names},"
  71. # f"left_col_names: {selection_param.left_col_names}")