homo_ohe_base.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. # added by jsweng
  19. # base class for OHE alignment
  20. import functools
  21. from federatedml.feature import one_hot_encoder
  22. from federatedml.param.homo_onehot_encoder_param import HomoOneHotParam
  23. from federatedml.transfer_variable.transfer_class.homo_onehot_transfer_variable import HomoOneHotTransferVariable
  24. from federatedml.util import LOGGER
  25. from federatedml.util import consts
  26. class HomoOneHotBase(one_hot_encoder.OneHotEncoder):
  27. def __init__(self):
  28. super(HomoOneHotBase, self).__init__()
  29. self.model_name = 'OHEAlignment'
  30. self.model_param_name = 'OHEAlignmentParam'
  31. self.model_meta_name = 'OHEAlignmentMeta'
  32. self.model_param = HomoOneHotParam()
  33. def _init_model(self, params):
  34. super(HomoOneHotBase, self)._init_model(params)
  35. # self.re_encrypt_batches = params.re_encrypt_batches
  36. self.need_alignment = params.need_alignment
  37. self.transfer_variable = HomoOneHotTransferVariable()
  38. def _init_params(self, data_instances):
  39. if data_instances is None:
  40. return
  41. super(HomoOneHotBase, self)._init_params(data_instances)
  42. def fit(self, data_instances):
  43. """This function allows for one-hot-encoding of the
  44. columns with or without alignment with the other parties
  45. in the federated learning.
  46. Args:
  47. data_instances: data the guest has access to
  48. Returns:
  49. if alignment is on, then the one-hot-encoding data_instances are done with
  50. alignment with parties involved in federated learning else,
  51. the data is one-hot-encoded independently
  52. """
  53. self._init_params(data_instances)
  54. self._abnormal_detection(data_instances)
  55. # keep a copy of original header
  56. ori_header = self.inner_param.header.copy()
  57. # obtain the individual column headers with their values
  58. f1 = functools.partial(self.record_new_header,
  59. inner_param=self.inner_param)
  60. self.col_maps = data_instances.applyPartitions(f1).reduce(self.merge_col_maps)
  61. col_maps = {}
  62. for col_name, pair_obj in self.col_maps.items():
  63. values = [x for x in pair_obj.values]
  64. col_maps[col_name] = values
  65. # LOGGER.debug("new col_maps is: {}".format(col_maps))
  66. if self.need_alignment:
  67. # Send col_maps to arbiter
  68. if self.role == consts.HOST:
  69. self.transfer_variable.host_columns.remote(col_maps, role=consts.ARBITER, idx=-1)
  70. elif self.role == consts.GUEST:
  71. self.transfer_variable.guest_columns.remote(col_maps, role=consts.ARBITER, idx=-1)
  72. # Receive aligned columns from arbiter
  73. aligned_columns = self.transfer_variable.aligned_columns.get(idx=-1)
  74. aligned_col_maps = aligned_columns[0]
  75. # LOGGER.debug("{} aligned columns received are: {}".format(self.role, aligned_col_maps))
  76. self.col_maps = {}
  77. for col_name, value_list in aligned_col_maps.items():
  78. value_set = set([str(x) for x in value_list])
  79. if len(value_set) != len(value_list):
  80. raise ValueError("Same values with different types have occurred among different parties")
  81. transfer_pair = one_hot_encoder.TransferPair(col_name)
  82. for v in value_list:
  83. transfer_pair.add_value(v)
  84. transfer_pair.encode_new_headers()
  85. self.col_maps[col_name] = transfer_pair
  86. self._transform_schema()
  87. data_instances = self.transform(data_instances)
  88. # LOGGER.debug(
  89. # "[Result][OHEAlignment{}] After transform in fit, schema is : {}, header: {}".format(self.role, self.schema,
  90. # self.inner_param.header))
  91. return data_instances