ecdh_intersect_guest.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.statistic.intersect.ecdh_intersect.ecdh_intersect_base import EcdhIntersect
  17. from federatedml.util import consts, LOGGER
  18. class EcdhIntersectionGuest(EcdhIntersect):
  19. def __init__(self):
  20. super().__init__()
  21. self.role = consts.GUEST
  22. self.id_local_first = None
  23. self.id_remote_second = None
  24. self.id_local_second = None
  25. self.host_count = None
  26. def _exchange_id(self, id_cipher, replace_val=True):
  27. if replace_val:
  28. id_cipher = id_cipher.mapValues(lambda v: None)
  29. self.transfer_variable.id_ciphertext_exchange_g2h.remote(id_cipher,
  30. role=consts.HOST,
  31. idx=-1)
  32. LOGGER.info(f"sent id 1st ciphertext to all host")
  33. id_list_remote = self.transfer_variable.id_ciphertext_exchange_h2g.get(idx=-1)
  34. LOGGER.info("got id ciphertext from all host")
  35. return id_list_remote
  36. def _sync_doubly_encrypted_id(self, id=None):
  37. id_guest = self.transfer_variable.doubly_encrypted_id.get(idx=-1)
  38. LOGGER.info("got doubly encrypted id list from host")
  39. return id_guest
  40. """
  41. def send_intersect_ids(self, intersect_ids):
  42. remote_intersect_id = intersect_ids.map(lambda k, v: (v, None))
  43. self.transfer_variable.intersect_ids.remote(remote_intersect_id,
  44. role=consts.HOST,
  45. idx=0)
  46. LOGGER.info(f"Remote intersect ids to Host!")
  47. """
  48. def send_intersect_ids(self, intersect_ids):
  49. for i, host_party_id in enumerate(self.host_party_id_list):
  50. remote_intersect_id = intersect_ids.map(lambda k, v: (v[i], None))
  51. self.transfer_variable.intersect_ids.remote(remote_intersect_id,
  52. role=consts.HOST,
  53. idx=i)
  54. LOGGER.info(f"Remote intersect ids to {i}th Host {host_party_id}!")
  55. def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
  56. self.init_curve()
  57. LOGGER.info(f"curve instance obtained")
  58. # 1st ID encrypt: # (Eg, -1)
  59. self.id_local_first = self._encrypt_id(data_instances,
  60. self.curve_instance,
  61. reserve_original_key=keep_key,
  62. hash_operator=self.hash_operator,
  63. salt=self.salt)
  64. LOGGER.info("encrypted guest id for the 1st time")
  65. id_list_remote_first = self._exchange_id(self.id_local_first, keep_key)
  66. # 2nd ID encrypt & receive doubly encrypted ID list: # (EEh, Eh)
  67. self.id_list_remote_second = [self._sign_id(id_remote_first,
  68. self.curve_instance,
  69. reserve_original_key=keep_key)
  70. for id_remote_first in id_list_remote_first]
  71. LOGGER.info("encrypted remote id for the 2nd time")
  72. # receive doubly encrypted ID list from all host:
  73. self.id_list_local_second = self._sync_doubly_encrypted_id() # get (EEg, Eg)
  74. # find intersection per host: (EEi, [Eg, Eh])
  75. id_list_intersect_cipher_cipher = [self.extract_intersect_ids(remote_cipher,
  76. local_cipher,
  77. keep_both=keep_key)
  78. for remote_cipher, local_cipher in zip(self.id_list_remote_second,
  79. self.id_list_local_second)]
  80. LOGGER.info("encrypted intersection ids found")
  81. return id_list_intersect_cipher_cipher
  82. def decrypt_intersect_doubly_encrypted_id(self, id_intersect_cipher_cipher):
  83. # EEi -> (Eg, Eh)
  84. id_list_intersect_cipher = [ids.map(lambda k, v: (v[1], [v[0]])) for ids in id_intersect_cipher_cipher]
  85. intersect_ids = self.get_common_intersection(id_list_intersect_cipher, keep_encrypt_ids=True)
  86. LOGGER.info(f"intersection found")
  87. if self.sync_intersect_ids:
  88. self.send_intersect_ids(intersect_ids)
  89. else:
  90. LOGGER.info("Skip sync intersect ids with Host(s).")
  91. intersect_ids = intersect_ids.join(self.id_local_first, lambda cipher, raw: raw)
  92. intersect_ids = intersect_ids.map(lambda k, v: (v, None))
  93. return intersect_ids
  94. def get_intersect_key(self, party_id):
  95. intersect_key = {"curve_key": self.curve_instance.get_curve_key().decode("latin1")}
  96. return intersect_key
  97. def load_intersect_key(self, cache_meta):
  98. host_party = self.host_party_id_list[0]
  99. intersect_key = cache_meta[str(host_party)]["intersect_key"]
  100. for host_party in self.host_party_id_list:
  101. cur_intersect_key = cache_meta[str(host_party)]["intersect_key"]
  102. if cur_intersect_key != cur_intersect_key:
  103. raise ValueError(f"Not all intersect keys from cache match, please check.")
  104. curve_key = intersect_key["curve_key"].encode("latin1")
  105. self.init_curve(curve_key)
  106. def generate_cache(self, data_instances):
  107. self.init_curve()
  108. LOGGER.info(f"curve instance obtained")
  109. cache_id_list = self.cache_transfer_variable.get(idx=-1)
  110. LOGGER.info(f"got cache_id from all host")
  111. id_list_remote_first = self.transfer_variable.id_ciphertext_exchange_h2g.get(idx=-1)
  112. LOGGER.info("Get id ciphertext list from all host")
  113. # 2nd ID encrypt & receive doubly encrypted ID list: # (EEh, Eh)
  114. id_remote_second = [self._sign_id(id_remote_first,
  115. self.curve_instance,
  116. reserve_original_key=True)
  117. for id_remote_first in id_list_remote_first]
  118. LOGGER.info("encrypted remote id for the 2nd time")
  119. cache_data, cache_meta = {}, {}
  120. intersect_meta = self.get_intersect_method_meta()
  121. for i, party_id in enumerate(self.host_party_id_list):
  122. meta = {"cache_id": cache_id_list[i],
  123. "intersect_meta": intersect_meta,
  124. "intersect_key": self.get_intersect_key(party_id)}
  125. cache_meta[party_id] = meta
  126. cache_data[party_id] = id_remote_second[i]
  127. return cache_data, cache_meta
  128. def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_data):
  129. self.id_local_first = self._encrypt_id(data_instances,
  130. self.curve_instance,
  131. reserve_original_key=True,
  132. hash_operator=self.hash_operator,
  133. salt=self.salt)
  134. LOGGER.info("encrypted guest id for the 1st time")
  135. id_only = self.id_local_first.mapValues(lambda v: None)
  136. self.transfer_variable.id_ciphertext_exchange_g2h.remote(id_only,
  137. role=consts.HOST,
  138. idx=-1)
  139. LOGGER.info(f"sent id 1st ciphertext to host")
  140. # receive doubly encrypted ID from all hosts:
  141. self.id_list_local_second = self._sync_doubly_encrypted_id() # get (EEg, Eg)
  142. self.host_count = len(self.id_list_local_second)
  143. # find intersection: (EEi, [Eg, Eh])
  144. cache_host_list = self.extract_cache_list(cache_data, self.host_party_id_list)
  145. id_list_intersect_cipher_cipher = [self.extract_intersect_ids(cache_host_list[i],
  146. self.id_list_local_second[i],
  147. keep_both=True)
  148. for i in range(self.host_count)]
  149. LOGGER.info("encrypted intersection ids found")
  150. self.id_remote_second = cache_host_list
  151. return id_list_intersect_cipher_cipher
  152. def run_cardinality(self, data_instances):
  153. LOGGER.info(f"run cardinality_only with ECDH")
  154. # EEg, Eg
  155. id_list_intersect_cipher_cipher = self.get_intersect_doubly_encrypted_id(data_instances,
  156. keep_key=False)
  157. # Eg
  158. id_intersect_cipher_cipher = self.filter_intersect_ids(id_list_intersect_cipher_cipher)
  159. self.intersect_num = id_intersect_cipher_cipher.count()
  160. if self.sync_cardinality:
  161. self.transfer_variable.cardinality.remote(self.intersect_num, role=consts.HOST, idx=-1)
  162. LOGGER.info("Sent intersect cardinality to host.")
  163. else:
  164. LOGGER.info("Skip sync intersect cardinality with host")