dh_intersect_host.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  #
  # Copyright 2021 The FATE Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  #
  16. import uuid
  17. from federatedml.secureprotol.symmetric_encryption.cryptor_executor import CryptoExecutor
  18. from federatedml.secureprotol.symmetric_encryption.pohlig_hellman_encryption import PohligHellmanCipherKey
  19. from federatedml.statistic.intersect.dh_intersect.dh_intersect_base import DhIntersect
  20. from federatedml.util import consts, LOGGER
  class DhIntersectionHost(DhIntersect):
      """Host party of the DH (commutative-cipher) ID intersection.

      The host receives the cipher's public knowledge from the guest, sends
      its own singly encrypted ids, encrypts the guest's singly encrypted ids
      a second time and sends them back, and optionally receives the resulting
      intersection (or only its cardinality) from the guest.
      """

      def __init__(self):
          super().__init__()
          self.role = consts.HOST
          # Host ids after the 1st encryption. Set by
          # get_intersect_doubly_encrypted_id() or
          # get_intersect_doubly_encrypted_id_from_cache() and read by
          # get_intersect_ids() to map cipher ids back to raw ids.
          self.id_list_local_first = None

      def _sync_commutative_cipher_public_knowledge(self):
          """Receive the commutative cipher from the guest and store it on self."""
          self.commutative_cipher = (
              self.transfer_variable.commutative_cipher_public_knowledge.get(idx=0)
          )
          LOGGER.info("got commutative cipher public knowledge from guest")

      def _exchange_id(self, id_cipher, replace_val=True):
          """Send the host's singly encrypted ids to the guest and fetch the guest's.

          Args:
              id_cipher: table of singly encrypted host ids; must support
                  ``mapValues``.
              replace_val: when true, every value is replaced by ``None``
                  before sending so that only the keys leave the host.

          Returns:
              The singly encrypted id list received from the guest.
          """
          if replace_val:
              id_cipher = id_cipher.mapValues(lambda v: None)
          self.transfer_variable.id_ciphertext_list_exchange_h2g.remote(
              id_cipher,
              role=consts.GUEST,
              idx=0,
          )
          LOGGER.info("sent id 1st ciphertext to guest")

          id_guest = self.transfer_variable.id_ciphertext_list_exchange_g2h.get(idx=0)
          LOGGER.info("got id 1st ciphertext from guest")
          return id_guest

      def _sync_doubly_encrypted_id_list(self, id_list):
          """Send the doubly encrypted guest id list back to the guest."""
          self.transfer_variable.doubly_encrypted_id_list.remote(
              id_list,
              role=consts.GUEST,
              idx=0,
          )
          LOGGER.info("sent doubly encrypted id list to guest")

      def get_intersect_ids(self):
          """Fetch the cipher intersect ids from the guest and map them to raw ids.

          Returns:
              The result of ``map_encrypt_id_to_raw_id`` applied to the received
              ids and ``self.id_list_local_first`` with ``keep_encrypt_id=False``.
          """
          first_cipher_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
          LOGGER.info("obtained cipher intersect ids from guest")
          return self.map_encrypt_id_to_raw_id(
              first_cipher_intersect_ids,
              self.id_list_local_first,
              keep_encrypt_id=False,
          )

      def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
          """Run the host side of the two-round encryption exchange.

          Args:
              data_instances: the host's local data table.
              keep_key: forwarded to ``_encrypt_id`` as ``reserve_original_key``
                  (both rounds) and ``reserve_original_value`` (1st round), and
                  to ``_exchange_id`` as ``replace_val``.
          """
          self._sync_commutative_cipher_public_knowledge()
          self.commutative_cipher.init()

          # 1st ID encrypt: (Eh, (h, Instance))
          self.id_list_local_first = self._encrypt_id(
              data_instances,
              self.commutative_cipher,
              reserve_original_key=keep_key,
              hash_operator=self.hash_operator,
              salt=self.salt,
              reserve_original_value=keep_key,
          )
          LOGGER.info("encrypted local id for the 1st time")

          # Send the host's 1st ciphertext, receive the guest's 1st ciphertext.
          id_list_remote_first = self._exchange_id(self.id_list_local_first, keep_key)

          # 2nd ID encrypt & send doubly encrypted guest ID list to guest: (EEg, Eg)
          id_list_remote_second = self._encrypt_id(
              id_list_remote_first,
              self.commutative_cipher,
              reserve_original_key=keep_key,
          )
          LOGGER.info("encrypted guest id for the 2nd time")
          self._sync_doubly_encrypted_id_list(id_list_remote_second)

      def decrypt_intersect_doubly_encrypted_id(self, id_list_intersect_cipher_cipher=None):
          """Return the raw intersect ids when the guest shares them.

          Args:
              id_list_intersect_cipher_cipher: unused on the host side; kept for
                  interface compatibility.

          Returns:
              The result of ``get_intersect_ids()`` when
              ``self.sync_intersect_ids`` is truthy, otherwise ``None``.
          """
          # The cardinality exchange is handled by run_cardinality(), not here.
          if self.sync_intersect_ids:
              return self.get_intersect_ids()
          return None

      def get_intersect_key(self, party_id=None):
          """Return the cipher key material as a dict of strings.

          Args:
              party_id: unused; kept for interface compatibility.

          Returns:
              ``{"mod_base": str, "exponent": str}`` taken from
              ``self.commutative_cipher.cipher_core``.
          """
          cipher_core = self.commutative_cipher.cipher_core
          return {
              "mod_base": str(cipher_core.mod_base),
              "exponent": str(cipher_core.exponent),
          }

      def load_intersect_key(self, cache_meta):
          """Rebuild ``self.commutative_cipher`` from cached key material.

          Args:
              cache_meta: mapping that holds, under ``str(self.guest_party_id)``,
                  a dict with an ``"intersect_key"`` entry as produced by
                  ``get_intersect_key()``.

          Raises:
              KeyError: if the expected entries are missing from ``cache_meta``.
          """
          # NOTE(review): generate_cache() keys its meta by the raw
          # guest_party_id while this lookup uses str(...) -- presumably the
          # meta is serialised in between; confirm against the caller.
          intersect_key = cache_meta[str(self.guest_party_id)]["intersect_key"]
          mod_base = int(intersect_key["mod_base"])
          exponent = int(intersect_key["exponent"])
          ph_key = PohligHellmanCipherKey(mod_base, exponent)
          self.commutative_cipher = CryptoExecutor(ph_key)

      def generate_cache(self, data_instances):
          """Encrypt the local ids once and package them as a reusable cache.

          Args:
              data_instances: the host's local data table.

          Returns:
              A ``(cache_data, cache_meta)`` tuple, both keyed by
              ``self.guest_party_id``. ``cache_data`` holds the singly encrypted
              id table; ``cache_meta`` holds the cache id, the intersect method
              meta and the intersect key.
          """
          self._sync_commutative_cipher_public_knowledge()
          self.commutative_cipher.init()

          cache_id = str(uuid.uuid4())
          self.cache_id = {self.guest_party_id: cache_id}
          self.cache_transfer_variable.remote(cache_id, role=consts.GUEST, idx=0)
          LOGGER.info("remote cache_id to guest")

          # 1st ID encrypt: (Eh, (h, Instance))
          id_list_local_first = self._encrypt_id(
              data_instances,
              self.commutative_cipher,
              reserve_original_key=True,
              hash_operator=self.hash_operator,
              salt=self.salt,
              reserve_original_value=True,
          )
          LOGGER.info("encrypted local id for the 1st time")

          # Only the encrypted keys are sent; the values stay on the host.
          id_only = id_list_local_first.mapValues(lambda v: None)
          self.transfer_variable.id_ciphertext_list_exchange_h2g.remote(
              id_only,
              role=consts.GUEST,
              idx=0,
          )
          LOGGER.info("sent id 1st ciphertext list to guest")

          cache_data = {self.guest_party_id: id_list_local_first}
          cache_meta = {
              self.guest_party_id: {
                  "cache_id": cache_id,
                  "intersect_meta": self.get_intersect_method_meta(),
                  "intersect_key": self.get_intersect_key(),
              }
          }
          return cache_data, cache_meta

      def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_data):
          """Run the 2nd encryption round, taking the host's 1st round from cache.

          Args:
              data_instances: unused here; the host's encrypted ids come from
                  ``cache_data`` instead.
              cache_data: cache from which ``extract_cache_list`` retrieves the
                  singly encrypted host ids for ``self.guest_party_id``.
          """
          id_list_remote_first = self.transfer_variable.id_ciphertext_list_exchange_g2h.get(idx=0)
          LOGGER.info("got id 1st ciphertext list from guest")

          # 2nd ID encrypt & send doubly encrypted guest ID list to guest: (EEg, Eg)
          id_list_remote_second = self._encrypt_id(
              id_list_remote_first,
              self.commutative_cipher,
              reserve_original_key=True,
          )
          LOGGER.info("encrypted guest id for the 2nd time")

          self.id_list_local_first = self.extract_cache_list(cache_data, self.guest_party_id)[0]
          self._sync_doubly_encrypted_id_list(id_list_remote_second)

      def run_cardinality(self, data_instances):
          """Run the exchange and, if enabled, receive the intersect cardinality.

          Args:
              data_instances: the host's local data table.

          When ``self.sync_cardinality`` is truthy the cardinality received from
          the guest is stored in ``self.intersect_num``.
          """
          LOGGER.info("run exact_cardinality with DH")
          self.get_intersect_doubly_encrypted_id(data_instances, keep_key=True)
          if self.sync_cardinality:
              self.intersect_num = self.transfer_variable.cardinality.get(idx=0)
              LOGGER.info("Got intersect cardinality from guest.")