123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- #
- # Copyright 2021 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import uuid
- from federatedml.secureprotol.symmetric_encryption.cryptor_executor import CryptoExecutor
- from federatedml.secureprotol.symmetric_encryption.pohlig_hellman_encryption import PohligHellmanCipherKey
- from federatedml.statistic.intersect.dh_intersect.dh_intersect_base import DhIntersect
- from federatedml.util import consts, LOGGER
- class DhIntersectionHost(DhIntersect):
- def __init__(self):
- super().__init__()
- self.role = consts.HOST
- self.id_list_local_first = None
- def _sync_commutative_cipher_public_knowledge(self):
- self.commutative_cipher = self.transfer_variable.commutative_cipher_public_knowledge.get(idx=0)
- LOGGER.info(f"got commutative cipher public knowledge from guest")
- def _exchange_id(self, id_cipher, replace_val=True):
- if replace_val:
- id_cipher = id_cipher.mapValues(lambda v: None)
- self.transfer_variable.id_ciphertext_list_exchange_h2g.remote(id_cipher,
- role=consts.GUEST,
- idx=0)
- LOGGER.info("sent id 1st ciphertext to guest")
- id_guest = self.transfer_variable.id_ciphertext_list_exchange_g2h.get(idx=0)
- LOGGER.info("got id 1st ciphertext from guest")
- return id_guest
- def _sync_doubly_encrypted_id_list(self, id_list):
- self.transfer_variable.doubly_encrypted_id_list.remote(id_list,
- role=consts.GUEST,
- idx=0)
- LOGGER.info("sent doubly encrypted id list to guest")
- def get_intersect_ids(self):
- first_cipher_intersect_ids = self.transfer_variable.intersect_ids.get(idx=0)
- LOGGER.info(f"obtained cipher intersect ids from guest")
- intersect_ids = self.map_encrypt_id_to_raw_id(first_cipher_intersect_ids,
- self.id_list_local_first,
- keep_encrypt_id=False)
- return intersect_ids
- def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
- self._sync_commutative_cipher_public_knowledge()
- self.commutative_cipher.init()
- # 1st ID encrypt: (Eh, (h, Instance))
- self.id_list_local_first = self._encrypt_id(data_instances,
- self.commutative_cipher,
- reserve_original_key=keep_key,
- hash_operator=self.hash_operator,
- salt=self.salt,
- reserve_original_value=keep_key)
- LOGGER.info("encrypted local id for the 1st time")
- # send (Eh, -1), get (Eg, -1)
- id_list_remote_first = self._exchange_id(self.id_list_local_first, keep_key)
- # 2nd ID encrypt & send doubly encrypted guest ID list to guest
- id_list_remote_second = self._encrypt_id(id_list_remote_first,
- self.commutative_cipher,
- reserve_original_key=keep_key) # (EEg, Eg)
- LOGGER.info("encrypted guest id for the 2nd time")
- self._sync_doubly_encrypted_id_list(id_list_remote_second)
- def decrypt_intersect_doubly_encrypted_id(self, id_list_intersect_cipher_cipher=None):
- """
- if self.cardinality_only:
- cardinality = None
- if self.sync_cardinality:
- cardinality = self.transfer_variable.cardinality.get(cardinality, role=consts.GUEST, idx=0)
- LOGGER.info(f"Got intersect cardinality from guest.")
- return cardinality
- """
- intersect_ids = None
- if self.sync_intersect_ids:
- intersect_ids = self.get_intersect_ids()
- return intersect_ids
- def get_intersect_key(self, party_id=None):
- cipher_core = self.commutative_cipher.cipher_core
- intersect_key = {"mod_base": str(cipher_core.mod_base),
- "exponent": str(cipher_core.exponent)}
- return intersect_key
- def load_intersect_key(self, cache_meta):
- intersect_key = cache_meta[str(self.guest_party_id)]["intersect_key"]
- mod_base = int(intersect_key["mod_base"])
- exponent = int(intersect_key["exponent"])
- ph_key = PohligHellmanCipherKey(mod_base, exponent)
- self.commutative_cipher = CryptoExecutor(ph_key)
- def generate_cache(self, data_instances):
- self._sync_commutative_cipher_public_knowledge()
- self.commutative_cipher.init()
- cache_id = str(uuid.uuid4())
- self.cache_id = {self.guest_party_id: cache_id}
- # id_only.schema = cache_schema
- self.cache_transfer_variable.remote(cache_id, role=consts.GUEST, idx=0)
- LOGGER.info(f"remote cache_id to guest")
- # 1st ID encrypt: (Eh, (h, Instance))
- id_list_local_first = self._encrypt_id(data_instances,
- self.commutative_cipher,
- reserve_original_key=True,
- hash_operator=self.hash_operator,
- salt=self.salt,
- reserve_original_value=True)
- LOGGER.info("encrypted local id for the 1st time")
- # cache_schema = {"cache_id": cache_id}
- # id_list_local_first.schema = cache_schema
- id_only = id_list_local_first.mapValues(lambda v: None)
- self.transfer_variable.id_ciphertext_list_exchange_h2g.remote(id_only,
- role=consts.GUEST,
- idx=0)
- LOGGER.info("sent id 1st ciphertext list to guest")
- cache_data = {self.guest_party_id: id_list_local_first}
- cache_meta = {self.guest_party_id: {"cache_id": cache_id,
- "intersect_meta": self.get_intersect_method_meta(),
- "intersect_key": self.get_intersect_key()}}
- return cache_data, cache_meta
- def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_data):
- id_list_remote_first = self.transfer_variable.id_ciphertext_list_exchange_g2h.get(idx=0)
- LOGGER.info("got id 1st ciphertext list from guest")
- # 2nd ID encrypt & send doubly encrypted guest ID list to guest
- id_list_remote_second = self._encrypt_id(id_list_remote_first,
- self.commutative_cipher,
- reserve_original_key=True) # (EEg, Eg)
- LOGGER.info("encrypted guest id for the 2nd time")
- self.id_list_local_first = self.extract_cache_list(cache_data, self.guest_party_id)[0]
- self._sync_doubly_encrypted_id_list(id_list_remote_second)
- def run_cardinality(self, data_instances):
- LOGGER.info(f"run exact_cardinality with DH")
- self.get_intersect_doubly_encrypted_id(data_instances, keep_key=True)
- if self.sync_cardinality:
- self.intersect_num = self.transfer_variable.cardinality.get(idx=0)
- LOGGER.info("Got intersect cardinality from guest.")
|