homo_ohe_arbiter.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. # added by jsweng
  19. # alignment arbiter
  20. from collections import defaultdict
  21. from federatedml.feature.homo_onehot.homo_ohe_base import HomoOneHotBase
  22. from federatedml.util import LOGGER
  23. from federatedml.util import consts
  24. class HomoOneHotArbiter(HomoOneHotBase):
  25. def __init__(self):
  26. super(HomoOneHotArbiter, self).__init__()
  27. def combine_all_column_headers(self, guest_columns, host_columns):
  28. """ This is used when there is a need for alignment within the
  29. federated learning. The function would align the column headers from
  30. guest and host and send the new aligned headers back.
  31. Returns:
  32. Combine all the column headers from guest and host
  33. if there is alignment is used
  34. """
  35. all_cols_dict = defaultdict(set)
  36. # Obtain all the guest headers
  37. for guest_cols in guest_columns:
  38. for k, v in guest_cols.items():
  39. all_cols_dict[k].update(v)
  40. # Obtain all the host headers
  41. for host_cols in host_columns:
  42. for k, v in host_cols.items():
  43. all_cols_dict[k].update(v)
  44. # Align all of them together
  45. combined_all_cols = {}
  46. for el in all_cols_dict.keys():
  47. combined_all_cols[el] = list(all_cols_dict[el])
  48. LOGGER.debug("{} combined cols: {}".format(self.role, combined_all_cols))
  49. return combined_all_cols
  50. def fit(self, data_instances=None):
  51. if self.need_alignment:
  52. guest_columns = self.transfer_variable.guest_columns.get(idx=-1) # getting guest column
  53. host_columns = self.transfer_variable.host_columns.get(idx=-1) # getting host column
  54. combined_all_cols = self.combine_all_column_headers(guest_columns, host_columns)
  55. # Send the aligned headers back to guest and host
  56. self.transfer_variable.aligned_columns.remote(combined_all_cols, role=consts.HOST, idx=-1)
  57. self.transfer_variable.aligned_columns.remote(combined_all_cols, role=consts.GUEST, idx=-1)
  58. def _get_meta(self):
  59. pass
  60. def _get_param(self):
  61. pass
  62. def export_model(self):
  63. return None
  64. def _load_model(self, model_dict):
  65. pass
  66. def transform(self, data_instances):
  67. pass
  68. def load_model(self, model_dict):
  69. pass