raw_intersect_base.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
#
# Copyright 2021 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. from federatedml.secureprotol.hash.hash_factory import Hash
  17. from federatedml.statistic.intersect import Intersect
  18. from federatedml.transfer_variable.transfer_class.raw_intersect_transfer_variable import RawIntersectTransferVariable
  19. from federatedml.util import consts, LOGGER
  20. class RawIntersect(Intersect):
  21. def __init__(self):
  22. super().__init__()
  23. self.role = None
  24. self.transfer_variable = RawIntersectTransferVariable()
  25. self.task_version_id = None
  26. self.tracker = None
  27. def load_params(self, param):
  28. # self.only_output_key = param.only_output_key
  29. # self.sync_intersect_ids = param.sync_intersect_ids
  30. super().load_params(param=param)
  31. self.raw_params = param.raw_params
  32. self.use_hash = self.raw_params.use_hash
  33. self.hash_method = self.raw_params.hash_method
  34. self.base64 = self.raw_params.base64
  35. self.salt = self.raw_params.salt
  36. self.join_role = self.raw_params.join_role
  37. self.hash_operator = Hash(self.hash_method, self.base64)
  38. def intersect_send_id(self, data_instances):
  39. sid_hash_pair = None
  40. if self.use_hash and self.hash_method != "none":
  41. sid_hash_pair = data_instances.map(
  42. lambda k, v: (Intersect.hash(k, self.hash_operator, self.salt), k))
  43. data_sid = sid_hash_pair.mapValues(lambda v: None)
  44. else:
  45. data_sid = data_instances.mapValues(lambda v: None)
  46. LOGGER.info("Send id role is {}".format(self.role))
  47. if self.role == consts.GUEST:
  48. send_ids_federation = self.transfer_variable.send_ids_guest
  49. recv_role = consts.HOST
  50. elif self.role == consts.HOST:
  51. send_ids_federation = self.transfer_variable.send_ids_host
  52. recv_role = consts.GUEST
  53. else:
  54. raise ValueError("Unknown intersect role, please check the code")
  55. send_ids_federation.remote(data_sid,
  56. role=recv_role,
  57. idx=-1)
  58. LOGGER.info("Remote data_sid to role-join")
  59. intersect_ids = None
  60. if self.sync_intersect_ids:
  61. if self.role == consts.HOST:
  62. intersect_ids_federation = self.transfer_variable.intersect_ids_guest
  63. elif self.role == consts.GUEST:
  64. intersect_ids_federation = self.transfer_variable.intersect_ids_host
  65. else:
  66. raise ValueError("Unknown intersect role, please check the code")
  67. recv_intersect_ids_list = intersect_ids_federation.get(idx=-1)
  68. LOGGER.info("Get intersect ids from role-join!")
  69. ids_list_size = len(recv_intersect_ids_list)
  70. LOGGER.info("recv_intersect_ids_list's size is {}".format(ids_list_size))
  71. recv_intersect_ids = self.get_common_intersection(recv_intersect_ids_list)
  72. if self.role == consts.GUEST and len(self.host_party_id_list) > 1:
  73. LOGGER.info(f"raw intersect send role is guest, "
  74. f"and has {self.host_party_id_list} hosts, remote the final intersect_ids to hosts")
  75. self.transfer_variable.sync_intersect_ids_multi_hosts.remote(recv_intersect_ids,
  76. role=consts.HOST,
  77. idx=-1)
  78. if sid_hash_pair and recv_intersect_ids is not None:
  79. hash_intersect_ids_map = recv_intersect_ids.join(sid_hash_pair, lambda r, s: s)
  80. intersect_ids = hash_intersect_ids_map.map(lambda k, v: (v, None))
  81. else:
  82. intersect_ids = recv_intersect_ids
  83. else:
  84. LOGGER.info("Not Get intersect ids from role-join!")
  85. return intersect_ids
  86. def intersect_join_id(self, data_instances):
  87. LOGGER.info("Join id role is {}".format(self.role))
  88. sid_hash_pair = None
  89. if self.use_hash and self.hash_method != "none":
  90. sid_hash_pair = data_instances.map(
  91. lambda k, v: (Intersect.hash(k, self.hash_operator, self.salt), k))
  92. data_sid = sid_hash_pair.mapValues(lambda v: None)
  93. else:
  94. data_sid = data_instances.mapValues(lambda v: None)
  95. if self.role == consts.HOST:
  96. send_ids_federation = self.transfer_variable.send_ids_guest
  97. elif self.role == consts.GUEST:
  98. send_ids_federation = self.transfer_variable.send_ids_host
  99. else:
  100. raise ValueError("Unknown intersect role, please check the code")
  101. recv_ids_list = send_ids_federation.get(idx=-1)
  102. ids_list_size = len(recv_ids_list)
  103. LOGGER.info("Get ids_list from role-send, ids_list size is {}".format(len(recv_ids_list)))
  104. if ids_list_size == 1:
  105. hash_intersect_ids = recv_ids_list[0].join(data_sid, lambda i, d: None)
  106. elif ids_list_size > 1:
  107. hash_intersect_ids_list = []
  108. for ids in recv_ids_list:
  109. hash_intersect_ids_list.append(ids.join(data_sid, lambda i, d: None))
  110. hash_intersect_ids = self.get_common_intersection(hash_intersect_ids_list)
  111. else:
  112. hash_intersect_ids = None
  113. LOGGER.info("Finish intersect_ids computing")
  114. if self.sync_intersect_ids:
  115. if self.role == consts.GUEST:
  116. intersect_ids_federation = self.transfer_variable.intersect_ids_guest
  117. send_role = consts.HOST
  118. elif self.role == consts.HOST:
  119. intersect_ids_federation = self.transfer_variable.intersect_ids_host
  120. send_role = consts.GUEST
  121. else:
  122. raise ValueError("Unknown intersect role, please check the code")
  123. intersect_ids_federation.remote(hash_intersect_ids,
  124. role=send_role,
  125. idx=-1)
  126. LOGGER.info("Remote intersect ids to role-send")
  127. if self.role == consts.HOST and len(self.host_party_id_list) > 1:
  128. LOGGER.info(f"raw intersect join role is host,"
  129. f"and has {self.host_party_id_list} hosts, get the final intersect_ids from guest")
  130. hash_intersect_ids = self.transfer_variable.sync_intersect_ids_multi_hosts.get(idx=0)
  131. if sid_hash_pair:
  132. hash_intersect_ids_map = hash_intersect_ids.join(sid_hash_pair, lambda r, s: s)
  133. intersect_ids = hash_intersect_ids_map.map(lambda k, v: (v, None))
  134. else:
  135. intersect_ids = hash_intersect_ids
  136. return intersect_ids