#
# Copyright 2021 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
- from federatedml.secureprotol.hash.hash_factory import Hash
- from federatedml.secureprotol.symmetric_encryption.cryptor_executor import CryptoExecutor
- from federatedml.secureprotol.symmetric_encryption.pohlig_hellman_encryption import PohligHellmanCipherKey
- from federatedml.statistic.intersect.base_intersect import Intersect
- from federatedml.transfer_variable.transfer_class.dh_intersect_transfer_variable import DhIntersectTransferVariable
- from federatedml.util import LOGGER, consts
class DhIntersect(Intersect):
    """Common base for the DH intersection method.

    Adapted from the Secure Information Retrieval module.

    The class stores the DH parameters, builds the commutative cipher and
    offers ID-encryption helpers.  The party-specific steps are either no-op
    hooks or raise ``NotImplementedError`` in this class.
    """

    def __init__(self):
        super().__init__()
        self.role = None
        self.transfer_variable = DhIntersectTransferVariable()
        # Created later by _generate_commutative_cipher().
        self.commutative_cipher = None

    def load_params(self, param):
        """Load the generic intersect parameters, then the DH-specific ones.

        :param param: parameter object exposing ``dh_params`` with the
            attributes ``hash_method``, ``salt`` and ``key_length``
        """
        super().load_params(param=param)
        dh_params = param.dh_params
        self.dh_params = dh_params
        self.hash_operator = Hash(dh_params.hash_method)
        self.salt = dh_params.salt
        self.key_length = dh_params.key_length

    def get_intersect_method_meta(self):
        """Return a dict describing the intersect method, hash method and salt.

        Reads ``self.dh_params`` and ``self.salt``, which are assigned in
        ``load_params``.
        """
        return {
            "intersect_method": consts.DH,
            "hash_method": self.dh_params.hash_method,
            "salt": self.salt,
        }

    @staticmethod
    def _encrypt_id(data_instances, cipher, reserve_original_key=False, hash_operator=None, salt='',
                    reserve_original_value=False):
        """Encrypt the key (ID) of the input Table with ``cipher``.

        The output layout is decided by the cipher from the mode passed to it;
        as documented for this helper (confirm against the cipher):
        (enc_key, ori_key) when only ``reserve_original_key`` is set,
        (enc_key, (ori_key, val)) when both flags are set,
        (ori_key, (enc_key, val)) when only ``reserve_original_value`` is set,
        and (enc_key, -1) otherwise.

        :param data_instances: Table whose keys are the IDs to encrypt
        :param cipher: cipher object providing ``map_encrypt`` and
            ``map_hash_encrypt``
        :param reserve_original_key: keep the original key in the output
        :param hash_operator: if provided, ``map_hash_encrypt`` is used
        :param salt: forwarded to ``map_hash_encrypt`` when hashing is used
        :param reserve_original_value: keep the original value in the output
        :return: whatever the chosen cipher method returns
        """
        mode = DhIntersect._get_mode(reserve_original_key, reserve_original_value)
        if hash_operator is None:
            # No hashing requested: encrypt the IDs directly.
            return cipher.map_encrypt(data_instances, mode=mode)
        return cipher.map_hash_encrypt(data_instances, mode=mode, hash_operator=hash_operator, salt=salt)

    @staticmethod
    def _get_mode(reserve_original_key=False, reserve_original_value=False):
        """Translate the two reserve flags into the cipher's integer mode.

        both -> 5, key only -> 4, value only -> 3, neither -> 1.
        """
        if reserve_original_key:
            return 5 if reserve_original_value else 4
        return 3 if reserve_original_value else 1

    # NOTE: an alternate variant that built one CryptoExecutor per entry of
    # self.host_party_id_list was previously parked here inside an inert
    # string literal; the method below builds a single cipher.
    def _generate_commutative_cipher(self):
        """Create a Pohlig-Hellman cipher of ``self.key_length`` and store it."""
        cipher_key = PohligHellmanCipherKey.generate_key(self.key_length)
        self.commutative_cipher = CryptoExecutor(cipher_key)

    def _sync_commutative_cipher_public_knowledge(self):
        """
        guest -> host public knowledge

        No-op in this class.
        :return:
        """
        return None

    def _exchange_id(self, id_cipher, replace_val=True):
        """
        No-op in this class.

        :param id_cipher: Table in the form (id, 0)
        :param replace_val: unused in this class
        :return:
        """
        return None

    def _sync_doubly_encrypted_id_list(self, id_list):
        """
        host -> guest

        No-op in this class.
        :param id_list:
        :return:
        """
        return None

    def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError("This method should not be called here")

    def decrypt_intersect_doubly_encrypted_id(self, id_list_intersect_cipher_cipher):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError("This method should not be called here")

    def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_set):
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError("This method should not be called here")

    def run_intersect(self, data_instances):
        """Run the DH intersection: get the doubly encrypted intersection, then decrypt it.

        :param data_instances: input Table
        :return: result of ``decrypt_intersect_doubly_encrypted_id``
        """
        LOGGER.info("Start DH Intersection")
        doubly_encrypted = self.get_intersect_doubly_encrypted_id(data_instances)
        return self.decrypt_intersect_doubly_encrypted_id(doubly_encrypted)

    def run_cache_intersect(self, data_instances, cache_data):
        """Run the DH intersection using cached data, then decrypt the result.

        :param data_instances: input Table
        :param cache_data: cache passed to
            ``get_intersect_doubly_encrypted_id_from_cache``
        :return: result of ``decrypt_intersect_doubly_encrypted_id``
        """
        LOGGER.info("Start DH Intersection with cache")
        doubly_encrypted = self.get_intersect_doubly_encrypted_id_from_cache(data_instances, cache_data)
        return self.decrypt_intersect_doubly_encrypted_id(doubly_encrypted)
|