123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- #
- # Copyright 2021 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from federatedml.statistic.intersect.ecdh_intersect.ecdh_intersect_base import EcdhIntersect
- from federatedml.util import consts, LOGGER
class EcdhIntersectionGuest(EcdhIntersect):
    """Guest-side driver of the ECDH-based ID intersection.

    The guest encrypts its own ids once, swaps first-round ciphertexts with
    every host, encrypts the hosts' ciphertexts a second time, and matches
    them against its own doubly encrypted ids returned by the hosts.

    Ciphertext shorthand used in the comments below (taken from the original
    inline notes): ``Eg`` / ``Eh`` are singly encrypted guest / host ids,
    ``EEg`` / ``EEh`` the doubly encrypted ones, and ``EEi`` a doubly
    encrypted id found on both sides.
    """

    def __init__(self):
        super().__init__()
        self.role = consts.GUEST
        self.id_local_first = None
        self.id_remote_second = None
        self.id_local_second = None
        self.host_count = None
        # NOTE(review): methods below also assign ``id_list_remote_second`` and
        # ``id_list_local_second``, which are not declared here -- confirm
        # whether the base class declares them or the names should be unified.

    def _exchange_id(self, id_cipher, replace_val=True):
        """Send first-round guest ciphertexts to all hosts and collect theirs.

        :param id_cipher: table of first-round encrypted guest ids.
        :param replace_val: when True, every value is replaced by ``None``
            before sending so that only the keys leave the guest.
        :return: what ``id_ciphertext_exchange_h2g.get(idx=-1)`` returns
            (presumably one table per host -- callers iterate over it).
        """
        if replace_val:
            id_cipher = id_cipher.mapValues(lambda v: None)
        self.transfer_variable.id_ciphertext_exchange_g2h.remote(id_cipher,
                                                                 role=consts.HOST,
                                                                 idx=-1)
        LOGGER.info("sent id 1st ciphertext to all host")

        id_list_remote = self.transfer_variable.id_ciphertext_exchange_h2g.get(idx=-1)
        LOGGER.info("got id ciphertext from all host")
        return id_list_remote

    def _sync_doubly_encrypted_id(self, id=None):
        """Receive the guest's doubly encrypted ids from all hosts.

        :param id: unused on the guest side; kept for signature compatibility.
        :return: what ``doubly_encrypted_id.get(idx=-1)`` returns.
        """
        id_guest = self.transfer_variable.doubly_encrypted_id.get(idx=-1)
        LOGGER.info("got doubly encrypted id list from host")
        return id_guest

    def send_intersect_ids(self, intersect_ids):
        """Send each host the intersection entries that belong to it.

        For the i-th host in ``self.host_party_id_list`` the table is remapped
        to ``(v[i], None)``, i.e. the i-th element of every value becomes the
        key (presumably that host's ciphertext of the common id -- verify
        against ``get_common_intersection``).

        :param intersect_ids: table whose values are indexable per host.
        """
        for i, host_party_id in enumerate(self.host_party_id_list):
            # Bind the index as a default argument: a plain closure over ``i``
            # would observe a later value if the map were evaluated lazily.
            remote_intersect_id = intersect_ids.map(
                lambda k, v, host_idx=i: (v[host_idx], None))
            self.transfer_variable.intersect_ids.remote(remote_intersect_id,
                                                        role=consts.HOST,
                                                        idx=i)
            LOGGER.info(f"Remote intersect ids to {i}th Host {host_party_id}!")

    def get_intersect_doubly_encrypted_id(self, data_instances, keep_key=True):
        """Run both encryption rounds and intersect per host.

        Side effects: initialises the curve and stores ``id_local_first``,
        ``id_list_remote_second`` and ``id_list_local_second`` on ``self``.

        :param data_instances: guest data whose ids are to be intersected.
        :param keep_key: forwarded as ``reserve_original_key`` to both
            encryption rounds, as ``replace_val`` to the exchange, and as
            ``keep_both`` to the per-host intersection.
        :return: list with one intersection table per host.
        """
        self.init_curve()
        LOGGER.info("curve instance obtained")

        # 1st ID encrypt: (Eg, -1)
        self.id_local_first = self._encrypt_id(data_instances,
                                               self.curve_instance,
                                               reserve_original_key=keep_key,
                                               hash_operator=self.hash_operator,
                                               salt=self.salt)
        LOGGER.info("encrypted guest id for the 1st time")
        id_list_remote_first = self._exchange_id(self.id_local_first, keep_key)

        # 2nd ID encrypt of every host's ciphertexts: (EEh, Eh)
        self.id_list_remote_second = [self._sign_id(id_remote_first,
                                                    self.curve_instance,
                                                    reserve_original_key=keep_key)
                                      for id_remote_first in id_list_remote_first]
        LOGGER.info("encrypted remote id for the 2nd time")

        # receive doubly encrypted guest ids from all hosts: (EEg, Eg)
        self.id_list_local_second = self._sync_doubly_encrypted_id()

        # find intersection per host: (EEi, [Eg, Eh])
        id_list_intersect_cipher_cipher = [
            self.extract_intersect_ids(remote_cipher, local_cipher, keep_both=keep_key)
            for remote_cipher, local_cipher in zip(self.id_list_remote_second,
                                                   self.id_list_local_second)
        ]
        LOGGER.info("encrypted intersection ids found")
        return id_list_intersect_cipher_cipher

    def decrypt_intersect_doubly_encrypted_id(self, id_intersect_cipher_cipher):
        """Turn per-host encrypted intersections into the final guest id table.

        :param id_intersect_cipher_cipher: list of per-host intersection
            tables as produced by ``get_intersect_doubly_encrypted_id``.
        :return: table keyed by the values stored in ``self.id_local_first``
            for the common ids (presumably the original guest ids), with
            ``None`` as every value.
        """
        # EEi -> (Eg, Eh): re-key every table by v[1] and wrap v[0] in a list
        id_list_intersect_cipher = [ids.map(lambda k, v: (v[1], [v[0]]))
                                    for ids in id_intersect_cipher_cipher]
        intersect_ids = self.get_common_intersection(id_list_intersect_cipher,
                                                     keep_encrypt_ids=True)
        LOGGER.info("intersection found")

        if self.sync_intersect_ids:
            self.send_intersect_ids(intersect_ids)
        else:
            LOGGER.info("Skip sync intersect ids with Host(s).")

        # Swap each ciphertext key for the value kept in the first-round table.
        intersect_ids = intersect_ids.join(self.id_local_first, lambda cipher, raw: raw)
        intersect_ids = intersect_ids.map(lambda k, v: (v, None))
        return intersect_ids

    def get_intersect_key(self, party_id):
        """Return the curve key in a string-serialisable form.

        :param party_id: unused; the same key is returned for every party.
        :return: dict with the single entry ``"curve_key"`` holding the curve
            key bytes decoded as latin-1.
        """
        intersect_key = {"curve_key": self.curve_instance.get_curve_key().decode("latin1")}
        return intersect_key

    def load_intersect_key(self, cache_meta):
        """Restore the curve from cached metadata.

        The key stored for the first host is the reference; the entry of every
        host in ``self.host_party_id_list`` must equal it.

        :param cache_meta: mapping ``str(host_party_id) -> meta`` where each
            meta holds an ``"intersect_key"`` entry.
        :raises ValueError: if the cached keys of the hosts differ.
        """
        host_party = self.host_party_id_list[0]
        intersect_key = cache_meta[str(host_party)]["intersect_key"]

        for host_party in self.host_party_id_list:
            cur_intersect_key = cache_meta[str(host_party)]["intersect_key"]
            # Compare against the reference key; the previous self-comparison
            # was always False, so mismatching caches were never detected.
            if cur_intersect_key != intersect_key:
                raise ValueError("Not all intersect keys from cache match, please check.")

        curve_key = intersect_key["curve_key"].encode("latin1")
        self.init_curve(curve_key)

    def generate_cache(self, data_instances):
        """Build the guest-side cache from the hosts' first-round ciphertexts.

        :param data_instances: not used by this method.
        :return: ``(cache_data, cache_meta)``, both keyed by host party id:
            ``cache_data`` holds the doubly encrypted host ids and
            ``cache_meta`` the cache id, intersect method meta and curve key.
        """
        self.init_curve()
        LOGGER.info("curve instance obtained")

        cache_id_list = self.cache_transfer_variable.get(idx=-1)
        LOGGER.info("got cache_id from all host")

        id_list_remote_first = self.transfer_variable.id_ciphertext_exchange_h2g.get(idx=-1)
        LOGGER.info("Get id ciphertext list from all host")

        # 2nd ID encrypt of every host's ciphertexts: (EEh, Eh)
        id_remote_second = [self._sign_id(id_remote_first,
                                          self.curve_instance,
                                          reserve_original_key=True)
                            for id_remote_first in id_list_remote_first]
        LOGGER.info("encrypted remote id for the 2nd time")

        cache_data, cache_meta = {}, {}
        intersect_meta = self.get_intersect_method_meta()
        for i, party_id in enumerate(self.host_party_id_list):
            meta = {"cache_id": cache_id_list[i],
                    "intersect_meta": intersect_meta,
                    "intersect_key": self.get_intersect_key(party_id)}
            cache_meta[party_id] = meta
            cache_data[party_id] = id_remote_second[i]
        return cache_data, cache_meta

    def get_intersect_doubly_encrypted_id_from_cache(self, data_instances, cache_data):
        """Intersect using cached doubly encrypted host ids.

        Uses ``self.curve_instance`` as it is; this method does not initialise
        the curve itself. Stores ``id_local_first``, ``id_list_local_second``,
        ``host_count`` and ``id_remote_second`` on ``self``.

        :param data_instances: guest data whose ids are to be intersected.
        :param cache_data: cached data passed to ``extract_cache_list``
            together with ``self.host_party_id_list``.
        :return: list with one intersection table per host.
        """
        self.id_local_first = self._encrypt_id(data_instances,
                                               self.curve_instance,
                                               reserve_original_key=True,
                                               hash_operator=self.hash_operator,
                                               salt=self.salt)
        LOGGER.info("encrypted guest id for the 1st time")

        # Only the ciphertext keys are sent; the values are blanked out.
        id_only = self.id_local_first.mapValues(lambda v: None)
        self.transfer_variable.id_ciphertext_exchange_g2h.remote(id_only,
                                                                 role=consts.HOST,
                                                                 idx=-1)
        LOGGER.info("sent id 1st ciphertext to host")

        # receive doubly encrypted guest ids from all hosts: (EEg, Eg)
        self.id_list_local_second = self._sync_doubly_encrypted_id()
        self.host_count = len(self.id_list_local_second)

        # find intersection: (EEi, [Eg, Eh])
        cache_host_list = self.extract_cache_list(cache_data, self.host_party_id_list)
        id_list_intersect_cipher_cipher = [
            self.extract_intersect_ids(cache_host_list[i],
                                       self.id_list_local_second[i],
                                       keep_both=True)
            for i in range(self.host_count)
        ]
        LOGGER.info("encrypted intersection ids found")
        self.id_remote_second = cache_host_list
        return id_list_intersect_cipher_cipher

    def run_cardinality(self, data_instances):
        """Compute only the intersection size and optionally share it.

        The count is stored in ``self.intersect_num``; when
        ``self.sync_cardinality`` is set it is also sent to all hosts.

        :param data_instances: guest data whose ids are to be intersected.
        """
        LOGGER.info("run cardinality_only with ECDH")
        # EEg, Eg
        id_list_intersect_cipher_cipher = self.get_intersect_doubly_encrypted_id(data_instances,
                                                                                 keep_key=False)
        # Eg
        id_intersect_cipher_cipher = self.filter_intersect_ids(id_list_intersect_cipher_cipher)
        self.intersect_num = id_intersect_cipher_cipher.count()

        if self.sync_cardinality:
            self.transfer_variable.cardinality.remote(self.intersect_num, role=consts.HOST, idx=-1)
            LOGGER.info("Sent intersect cardinality to host.")
        else:
            LOGGER.info("Skip sync intersect cardinality with host")
|