123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- #
- # Copyright 2021 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import uuid
- from federatedml.statistic.intersect.ecdh_intersect.ecdh_intersect_base import EcdhIntersect
- from federatedml.util import consts, LOGGER
class EcdhIntersectionHost(EcdhIntersect):
    """Host-side role of the ECDH intersection flow.

    The host encrypts its own ids once, swaps first-round ciphertexts with
    the guest, applies a second encryption to the guest's ciphertexts and
    sends them back. Intersection results (ids or cardinality) are only
    fetched from the guest when the corresponding sync flag is set.
    """

    def __init__(self):
        super().__init__()
        self.role = consts.HOST
        # Local ids after the 1st encryption. Populated by
        # get_intersect_doubly_encrypted_id() or
        # get_intersect_doubly_encrypted_id_from_cache().
        self.id_local_first = None

    def _exchange_id(self, id, replace_val=True):
        """Send the host's 1st-round ciphertexts to the guest and fetch the guest's.

        Parameters
        ----------
        id : table of encrypted host ids (must support ``mapValues``
            when ``replace_val`` is true)
        replace_val : bool
            When true, every value is replaced by ``None`` before sending so
            only the keys are transferred; otherwise ``id`` is sent as is.

        Returns
        -------
        The object received from the guest on ``id_ciphertext_exchange_g2h``.
        """
        outgoing = id.mapValues(lambda v: None) if replace_val else id
        self.transfer_variable.id_ciphertext_exchange_h2g.remote(
            outgoing, role=consts.GUEST, idx=0
        )
        LOGGER.info("sent id 1st ciphertext list to guest")

        guest_ciphertexts = self.transfer_variable.id_ciphertext_exchange_g2h.get(idx=0)
        LOGGER.info("got id 1st ciphertext list from guest")
        return guest_ciphertexts

    def _sync_doubly_encrypted_id(self, id):
        """Send the doubly encrypted guest ids ``id`` to the guest."""
        self.transfer_variable.doubly_encrypted_id.remote(
            id, role=consts.GUEST, idx=0
        )
        LOGGER.info("sent doubly encrypted id list to guest")

    def get_intersect_ids(self):
        """Receive the guest's intersect ciphertexts and map them to raw ids.

        The mapping uses ``self.id_local_first``, so one of the
        ``get_intersect_doubly_encrypted_id*`` methods must have run first.
        """
        cipher_intersect = self.transfer_variable.intersect_ids.get(idx=0)
        LOGGER.info(f"obtained cipher intersect ids from guest")
        return self.map_encrypt_id_to_raw_id(
            cipher_intersect, self.id_local_first, keep_encrypt_id=False
        )

    def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
        """Run the host's two encryption rounds against the guest.

        ``keep_key`` is forwarded as ``reserve_original_key`` and
        ``reserve_original_value`` to the 1st encryption, as ``replace_val``
        to the exchange, and as ``reserve_original_key`` to the 2nd
        encryption. Nothing is returned; the 1st-round result is kept in
        ``self.id_local_first`` and the 2nd-round result is sent to the guest.
        """
        self.init_curve()
        LOGGER.info(f"curve instance obtained")

        # 1st ID encrypt: (Eh, (h, Instance))
        self.id_local_first = self._encrypt_id(
            data_instances,
            self.curve_instance,
            reserve_original_key=keep_key,
            hash_operator=self.hash_operator,
            salt=self.salt,
            reserve_original_value=keep_key,
        )
        LOGGER.info("encrypted local id for the 1st time")

        # Send Eh (values blanked to None when keep_key is true), receive Eg.
        guest_first = self._exchange_id(self.id_local_first, keep_key)

        # 2nd ID encrypt of the guest's ciphertexts: (EEg, Eg)
        guest_second = self._sign_id(
            guest_first, self.curve_instance, reserve_original_key=keep_key
        )
        LOGGER.info("encrypted guest id for the 2nd time")

        self._sync_doubly_encrypted_id(guest_second)

    def decrypt_intersect_doubly_encrypted_id(self, id_intersect_cipher_cipher=None):
        """Return the intersect ids if the guest syncs them, else ``None``.

        ``id_intersect_cipher_cipher`` is accepted but not used here.
        """
        if not self.sync_intersect_ids:
            return None
        return self.get_intersect_ids()

    def get_intersect_key(self, party_id=None):
        """Return the curve key as a latin1-decoded string in a dict.

        ``party_id`` is accepted but not used here.
        """
        curve_key_text = self.curve_instance.get_curve_key().decode("latin1")
        return {"curve_key": curve_key_text}

    def load_intersect_key(self, cache_meta):
        """Re-initialise the curve from the key stored in ``cache_meta``.

        The entry is looked up under ``str(self.guest_party_id)`` and the
        stored text is encoded back to bytes with latin1.
        """
        guest_meta = cache_meta[str(self.guest_party_id)]
        curve_key = guest_meta["intersect_key"]["curve_key"].encode("latin1")
        self.init_curve(curve_key)

    def generate_cache(self, data_instances):
        """Encrypt the host ids once, share them with the guest and build a cache.

        Returns
        -------
        tuple
            ``(cache_data, cache_meta)``, both dicts keyed by
            ``self.guest_party_id``. ``cache_data`` holds the 1st-round
            encrypted ids; ``cache_meta`` holds the cache id, the intersect
            method meta and the intersect key.
        """
        self.init_curve()
        LOGGER.info(f"curve instance obtained")

        guest_id = self.guest_party_id
        cache_id = str(uuid.uuid4())
        self.cache_id = {guest_id: cache_id}
        self.cache_transfer_variable.remote(cache_id, role=consts.GUEST, idx=0)
        LOGGER.info(f"remote cache_id to guest")

        # 1st ID encrypt: (Eh, (h, Instance))
        encrypted_local = self._encrypt_id(
            data_instances,
            self.curve_instance,
            reserve_original_key=True,
            hash_operator=self.hash_operator,
            salt=self.salt,
            reserve_original_value=True,
        )
        LOGGER.info("encrypted local id for the 1st time")

        # Only the ciphertext keys are sent; values are blanked to None.
        self.transfer_variable.id_ciphertext_exchange_h2g.remote(
            encrypted_local.mapValues(lambda v: None), role=consts.GUEST, idx=0
        )
        LOGGER.info("sent id 1st ciphertext list to guest")

        cache_data = {guest_id: encrypted_local}
        # Meta first, then key: same evaluation order as the dict literal.
        method_meta = self.get_intersect_method_meta()
        intersect_key = self.get_intersect_key()
        cache_meta = {
            guest_id: {
                "cache_id": cache_id,
                "intersect_meta": method_meta,
                "intersect_key": intersect_key,
            }
        }
        return cache_data, cache_meta

    def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_data):
        """Cached variant: re-encrypt the guest's ids and restore local ids.

        The host's 1st-round ids are taken from ``cache_data`` rather than
        recomputed; ``data_instances`` is accepted but not used here. The
        curve in ``self.curve_instance`` must already be initialised.
        """
        guest_first = self.transfer_variable.id_ciphertext_exchange_g2h.get(idx=0)
        LOGGER.info("got id 1st ciphertext from guest")

        # 2nd ID encrypt of the guest's ciphertexts: (EEg, Eg)
        guest_second = self._sign_id(
            guest_first, self.curve_instance, reserve_original_key=True
        )
        LOGGER.info("encrypted guest id for the 2nd time")

        cached_tables = self.extract_cache_list(cache_data, self.guest_party_id)
        self.id_local_first = cached_tables[0]

        self._sync_doubly_encrypted_id(guest_second)

    def run_cardinality(self, data_instances):
        """Run the exchange and, if synced, record the intersect cardinality.

        ``self.intersect_num`` is only assigned when ``self.sync_cardinality``
        is truthy.
        """
        LOGGER.info(f"run exact_cardinality with DH")
        self.get_intersect_doubly_encrypted_id(data_instances, keep_key=True)

        if not self.sync_cardinality:
            return
        self.intersect_num = self.transfer_variable.cardinality.get(idx=0)
        LOGGER.info("Got intersect cardinality from guest.")
|