dh_intersect_guest.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.secureprotol.symmetric_encryption.cryptor_executor import CryptoExecutor
  17. from federatedml.secureprotol.symmetric_encryption.pohlig_hellman_encryption import PohligHellmanCipherKey
  18. from federatedml.statistic.intersect.dh_intersect.dh_intersect_base import DhIntersect
  19. from federatedml.util import consts, LOGGER
  20. class DhIntersectionGuest(DhIntersect):
  21. def __init__(self):
  22. super().__init__()
  23. self.role = consts.GUEST
  24. self.id_list_local_first = None
  25. self.id_local_first = None
  26. self.id_list_remote_second = None
  27. self.id_list_local_second = None
  28. self.host_count = None
  29. # self.recorded_k_data = None
  30. """def _sync_commutative_cipher_public_knowledge(self):
  31. for i, _ in enumerate(self.host_party_id_list):
  32. self.transfer_variable.commutative_cipher_public_knowledge.remote(self.commutative_cipher[i],
  33. role=consts.HOST,
  34. idx=i)
  35. LOGGER.info(f"sent commutative cipher public knowledge to {i}th host")"""
  36. def _sync_commutative_cipher_public_knowledge(self):
  37. self.transfer_variable.commutative_cipher_public_knowledge.remote(self.commutative_cipher,
  38. role=consts.HOST,
  39. idx=-1)
  40. LOGGER.info(f"sent commutative cipher public knowledge to all host")
  41. def _exchange_id(self, id_cipher, replace_val=True):
  42. """for i, id in enumerate(id_list):
  43. if replace_val:
  44. id_only = id.mapValues(lambda v: None)
  45. else:
  46. id_only = id
  47. self.transfer_variable.id_ciphertext_list_exchange_g2h.remote(id_only,
  48. role=consts.HOST,
  49. idx=i)
  50. LOGGER.info(f"sent id 1st ciphertext list to {i} th host")"""
  51. if replace_val:
  52. id_cipher = id_cipher.mapValues(lambda v: None)
  53. self.transfer_variable.id_ciphertext_list_exchange_g2h.remote(id_cipher,
  54. role=consts.HOST,
  55. idx=-1)
  56. LOGGER.info(f"sent id 1st ciphertext to all host")
  57. id_list_remote = self.transfer_variable.id_ciphertext_list_exchange_h2g.get(idx=-1)
  58. LOGGER.info("got id ciphertext list from all host")
  59. return id_list_remote
  60. def _sync_doubly_encrypted_id_list(self, id_list=None):
  61. id_list_guest = self.transfer_variable.doubly_encrypted_id_list.get(idx=-1)
  62. LOGGER.info("got doubly encrypted id list from all host")
  63. return id_list_guest
  64. """
  65. def send_intersect_ids(self, encrypt_intersect_ids_list, intersect_ids):
  66. if len(self.host_party_id_list) > 1:
  67. for i, host_party_id in enumerate(self.host_party_id_list):
  68. remote_intersect_id = intersect_ids.map(lambda k, v: (v[i], None))
  69. self.transfer_variable.intersect_ids.remote(remote_intersect_id,
  70. role=consts.HOST,
  71. idx=i)
  72. LOGGER.info(f"Remote intersect ids to Host {host_party_id}!")
  73. else:
  74. remote_intersect_id = encrypt_intersect_ids_list[0].mapValues(lambda v: None)
  75. self.transfer_variable.intersect_ids.remote(remote_intersect_id,
  76. role=consts.HOST,
  77. idx=0)
  78. LOGGER.info(f"Remote intersect ids to Host!")
  79. """
  80. def send_intersect_ids(self, intersect_ids):
  81. for i, host_party_id in enumerate(self.host_party_id_list):
  82. remote_intersect_id = intersect_ids.map(lambda k, v: (v[i], None))
  83. self.transfer_variable.intersect_ids.remote(remote_intersect_id,
  84. role=consts.HOST,
  85. idx=i)
  86. LOGGER.info(f"Remote intersect ids to {i}th Host {host_party_id}!")
  87. def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
  88. self._generate_commutative_cipher()
  89. self._sync_commutative_cipher_public_knowledge()
  90. self.host_count = len(self.host_party_id_list)
  91. self.commutative_cipher.init()
  92. LOGGER.info("commutative cipher key generated")
  93. # 1st ID encrypt: # (Eg, -1)
  94. self.id_local_first = self._encrypt_id(data_instances,
  95. self.commutative_cipher,
  96. reserve_original_key=keep_key,
  97. hash_operator=self.hash_operator,
  98. salt=self.salt)
  99. LOGGER.info("encrypted guest id for the 1st time")
  100. id_list_remote_first = self._exchange_id(self.id_local_first, keep_key)
  101. # 2nd ID encrypt & receive doubly encrypted ID list: # (EEh, Eh)
  102. self.id_list_remote_second = [self._encrypt_id(id_list_remote_first[i],
  103. self.commutative_cipher,
  104. reserve_original_key=keep_key)
  105. for i in range(self.host_count)]
  106. LOGGER.info("encrypted remote id for the 2nd time")
  107. # receive doubly encrypted ID list from all host:
  108. self.id_list_local_second = self._sync_doubly_encrypted_id_list() # get (EEg, Eg)
  109. # find intersection per host
  110. id_list_intersect_cipher_cipher = [self.extract_intersect_ids(self.id_list_remote_second[i],
  111. self.id_list_local_second[i],
  112. keep_both=keep_key)
  113. for i in range(self.host_count)] # (EEi, [Eh, Eg])
  114. LOGGER.info("encrypted intersection ids found")
  115. return id_list_intersect_cipher_cipher
  116. def decrypt_intersect_doubly_encrypted_id(self, id_list_intersect_cipher_cipher):
  117. # EEi -> (Eg, Eh)
  118. id_list_intersect_cipher = [ids.map(lambda k, v: (v[1], [v[0]])) for ids in id_list_intersect_cipher_cipher]
  119. intersect_ids = self.get_common_intersection(id_list_intersect_cipher, keep_encrypt_ids=True)
  120. LOGGER.info(f"intersection found")
  121. if self.sync_intersect_ids:
  122. self.send_intersect_ids(intersect_ids)
  123. else:
  124. LOGGER.info("Skip sync intersect ids with Host(s).")
  125. intersect_ids = intersect_ids.join(self.id_local_first, lambda cipher, raw: raw)
  126. intersect_ids = intersect_ids.map(lambda k, v: (v, None))
  127. return intersect_ids
  128. def load_intersect_key(self, cache_meta):
  129. host_party = self.host_party_id_list[0]
  130. intersect_key = cache_meta[str(host_party)]["intersect_key"]
  131. mod_base = int(intersect_key["mod_base"])
  132. exponent = int(intersect_key["exponent"])
  133. for host_party in self.host_party_id_list:
  134. cur_intersect_key = cache_meta[str(host_party)]["intersect_key"]
  135. cur_mod_base = int(cur_intersect_key["mod_base"])
  136. cur_exponent = int(cur_intersect_key["exponent"])
  137. if cur_mod_base != mod_base or cur_exponent != exponent:
  138. raise ValueError("Not all intersect keys from cache match, please check.")
  139. ph_key = PohligHellmanCipherKey(mod_base, exponent)
  140. self.commutative_cipher = CryptoExecutor(ph_key)
  141. def generate_cache(self, data_instances):
  142. self._generate_commutative_cipher()
  143. self._sync_commutative_cipher_public_knowledge()
  144. self.commutative_cipher.init()
  145. LOGGER.info("commutative cipher key generated")
  146. cache_id_list = self.cache_transfer_variable.get(idx=-1)
  147. LOGGER.info(f"got cache_id from all host")
  148. id_list_remote_first = self.transfer_variable.id_ciphertext_list_exchange_h2g.get(idx=-1)
  149. LOGGER.info("Get id ciphertext list from all host")
  150. # 2nd ID encrypt & receive doubly encrypted ID list: # (EEh, Eh)
  151. id_list_remote_second = [self._encrypt_id(id_remote_first,
  152. self.commutative_cipher,
  153. reserve_original_key=True)
  154. for id_remote_first in id_list_remote_first]
  155. LOGGER.info("encrypted remote id for the 2nd time")
  156. cache_data, cache_meta = {}, {}
  157. intersect_meta = self.get_intersect_method_meta()
  158. cipher_core = self.commutative_cipher.cipher_core
  159. intersect_key = {"mod_base": str(cipher_core.mod_base),
  160. "exponent": str(cipher_core.exponent)}
  161. for i, party_id in enumerate(self.host_party_id_list):
  162. meta = {"cache_id": cache_id_list[i],
  163. "intersect_meta": intersect_meta,
  164. "intersect_key": intersect_key}
  165. cache_meta[party_id] = meta
  166. cache_data[party_id] = id_list_remote_second[i]
  167. return cache_data, cache_meta
  168. def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_data):
  169. self.id_local_first = self._encrypt_id(data_instances,
  170. self.commutative_cipher,
  171. reserve_original_key=True,
  172. hash_operator=self.hash_operator,
  173. salt=self.salt)
  174. LOGGER.info("encrypted guest id for the 1st time")
  175. id_only = self.id_local_first.mapValues(lambda v: None)
  176. self.transfer_variable.id_ciphertext_list_exchange_g2h.remote(id_only,
  177. role=consts.HOST,
  178. idx=-1)
  179. LOGGER.info(f"sent id 1st ciphertext list to all host")
  180. # receive doubly encrypted ID list from all host:
  181. self.id_list_local_second = self._sync_doubly_encrypted_id_list() # get (EEg, Eg)
  182. self.host_count = len(self.id_list_local_second)
  183. # find intersection per host
  184. cache_list = self.extract_cache_list(cache_data, self.host_party_id_list)
  185. id_list_intersect_cipher_cipher = [self.extract_intersect_ids(cache_list[i],
  186. self.id_list_local_second[i],
  187. keep_both=True)
  188. for i in range(self.host_count)] # (EEi, [Eh, Eg])
  189. LOGGER.info("encrypted intersection ids found")
  190. self.id_list_remote_second = cache_list
  191. return id_list_intersect_cipher_cipher
  192. def run_cardinality(self, data_instances):
  193. LOGGER.info(f"run cardinality_only with DH")
  194. id_list_intersect_cipher_cipher = self.get_intersect_doubly_encrypted_id(data_instances, keep_key=False)
  195. id_intersect_cipher_cipher = self.filter_intersect_ids(id_list_intersect_cipher_cipher,
  196. keep_encrypt_ids=False)
  197. self.intersect_num = id_intersect_cipher_cipher.count()
  198. if self.sync_cardinality:
  199. self.transfer_variable.cardinality.remote(self.intersect_num, role=consts.HOST, idx=-1)
  200. LOGGER.info("Sent intersect cardinality to host.")
  201. else:
  202. LOGGER.info("Skip sync intersect cardinality with host")