dh_intersect_base.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.secureprotol.hash.hash_factory import Hash
  17. from federatedml.secureprotol.symmetric_encryption.cryptor_executor import CryptoExecutor
  18. from federatedml.secureprotol.symmetric_encryption.pohlig_hellman_encryption import PohligHellmanCipherKey
  19. from federatedml.statistic.intersect.base_intersect import Intersect
  20. from federatedml.transfer_variable.transfer_class.dh_intersect_transfer_variable import DhIntersectTransferVariable
  21. from federatedml.util import LOGGER, consts
  22. class DhIntersect(Intersect):
  23. """
  24. adapted from Secure Information Retrieval Module
  25. """
  26. def __init__(self):
  27. super().__init__()
  28. self.role = None
  29. self.transfer_variable = DhIntersectTransferVariable()
  30. self.commutative_cipher = None
  31. def load_params(self, param):
  32. super().load_params(param=param)
  33. self.dh_params = param.dh_params
  34. self.hash_operator = Hash(param.dh_params.hash_method)
  35. self.salt = self.dh_params.salt
  36. self.key_length = self.dh_params.key_length
  37. def get_intersect_method_meta(self):
  38. dh_meta = {"intersect_method": consts.DH,
  39. "hash_method": self.dh_params.hash_method,
  40. "salt": self.salt}
  41. return dh_meta
  42. @staticmethod
  43. def _encrypt_id(data_instances, cipher, reserve_original_key=False, hash_operator=None, salt='',
  44. reserve_original_value=False):
  45. """
  46. Encrypt the key (ID) of input Table
  47. :param cipher: cipher object
  48. :param data_instance: Table
  49. :param reserve_original_key: (enc_key, ori_key) if reserve_original_key == True, otherwise (enc_key, -1)
  50. :param hash_operator: if provided, use map_hash_encrypt
  51. :param salt: if provided, use for map_hash_encrypt
  52. : param reserve_original_value:
  53. (enc_key, (ori_key, val)) for reserve_original_key == True and reserve_original_value==True;
  54. (ori_key, (enc_key, val)) for only reserve_original_value == True.
  55. :return:
  56. """
  57. mode = DhIntersect._get_mode(reserve_original_key, reserve_original_value)
  58. if hash_operator is not None:
  59. return cipher.map_hash_encrypt(data_instances, mode=mode, hash_operator=hash_operator, salt=salt)
  60. return cipher.map_encrypt(data_instances, mode=mode)
  61. @staticmethod
  62. def _get_mode(reserve_original_key=False, reserve_original_value=False):
  63. if reserve_original_key and reserve_original_value:
  64. return 5
  65. if reserve_original_key:
  66. return 4
  67. if reserve_original_value:
  68. return 3
  69. return 1
  70. """
  71. def _generate_commutative_cipher(self):
  72. self.commutative_cipher = [
  73. CryptoExecutor(PohligHellmanCipherKey.generate_key(self.key_length)) for _ in self.host_party_id_list
  74. ]
  75. """
  76. def _generate_commutative_cipher(self):
  77. self.commutative_cipher = CryptoExecutor(PohligHellmanCipherKey.generate_key(self.key_length))
  78. def _sync_commutative_cipher_public_knowledge(self):
  79. """
  80. guest -> host public knowledge
  81. :return:
  82. """
  83. pass
  84. def _exchange_id(self, id_cipher, replace_val=True):
  85. """
  86. :param id_cipher: Table in the form (id, 0)
  87. :return:
  88. """
  89. pass
  90. def _sync_doubly_encrypted_id_list(self, id_list):
  91. """
  92. host -> guest
  93. :param id_list:
  94. :return:
  95. """
  96. pass
  97. def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
  98. raise NotImplementedError("This method should not be called here")
  99. def decrypt_intersect_doubly_encrypted_id(self, id_list_intersect_cipher_cipher):
  100. raise NotImplementedError("This method should not be called here")
  101. def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_set):
  102. raise NotImplementedError("This method should not be called here")
  103. def run_intersect(self, data_instances):
  104. LOGGER.info("Start DH Intersection")
  105. id_list_intersect_cipher_cipher = self.get_intersect_doubly_encrypted_id(data_instances)
  106. intersect_ids = self.decrypt_intersect_doubly_encrypted_id(id_list_intersect_cipher_cipher)
  107. return intersect_ids
  108. def run_cache_intersect(self, data_instances, cache_data):
  109. LOGGER.info("Start DH Intersection with cache")
  110. id_list_intersect_cipher_cipher = self.get_intersect_doubly_encrypted_id_from_cache(data_instances, cache_data)
  111. intersect_ids = self.decrypt_intersect_doubly_encrypted_id(id_list_intersect_cipher_cipher)
  112. return intersect_ids