|
- #
- # Copyright 2021 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import gmpy2
- from federatedml.statistic.intersect.rsa_intersect.rsa_intersect_base import RsaIntersect
- from federatedml.util import consts, LOGGER
class RsaIntersectionGuest(RsaIntersect):
    """Guest side of the RSA intersection protocol.

    The guest blinds its ids with each host's RSA public key (a random factor
    ``r`` is kept per id), receives the host's signatures of the blinded ids,
    removes ``r`` and hashes the result, then matches those hashes against the
    private-key-processed ids the host sends.  In split mode the keys are
    partitioned by parity: odd keys follow the flow above, while for even keys
    the roles are mirrored (the guest signs host ids with its own keys and the
    host reports the even intersection back).

    NOTE: every function passed to a table ``map``/``join``/``filter`` is built
    by a small factory that binds the per-host values it needs.  A lambda
    created directly inside a loop/comprehension would close over the loop
    variable itself, so a backend that serializes or runs it after the loop
    moved on would see the last host's index for every host.
    """

    def __init__(self):
        super().__init__()
        self.role = consts.GUEST

    # ------------------------------------------------------------------
    # transfer helpers
    # ------------------------------------------------------------------
    def get_host_prvkey_ids(self):
        """Receive the private-key-processed ids from all hosts (one entry per host)."""
        host_prvkey_ids_list = self.transfer_variable.host_prvkey_ids.get(idx=-1)
        LOGGER.info("Get host_prvkey_ids from all host")
        return host_prvkey_ids_list

    def get_host_filter(self):
        """Receive the id filter from all hosts (consumed by ``run_cardinality``)."""
        host_filter_list = self.transfer_variable.host_filter.get(idx=-1)
        LOGGER.info("Get host_filter from all host")
        return host_filter_list

    def get_host_pubkey_ids(self):
        """Receive each host's public-key-processed ids, to be signed by the guest (split mode)."""
        host_pubkey_ids_list = self.transfer_variable.host_pubkey_ids.get(idx=-1)
        LOGGER.info("Get host_pubkey_ids from all host")
        return host_pubkey_ids_list

    def _recv_host_pubkeys(self, log_msg):
        """Receive every host's public key and store ``e``/``n`` as int lists ``rcv_e``/``rcv_n``.

        Args:
            log_msg: message logged right after the keys are received.
        """
        public_keys = self.transfer_variable.host_pubkey.get(idx=-1)
        LOGGER.info(log_msg)
        self.rcv_e = [int(public_key["e"]) for public_key in public_keys]
        self.rcv_n = [int(public_key["n"]) for public_key in public_keys]

    def _remote_guest_pubkey_ids(self, pubkey_ids_process_list):
        """Send host i the keys of the i-th blinded table (values masked to ``None``)."""
        for i, guest_id in enumerate(pubkey_ids_process_list):
            mask_guest_id = guest_id.mapValues(lambda v: None)
            self.transfer_variable.guest_pubkey_ids.remote(mask_guest_id,
                                                           role=consts.HOST,
                                                           idx=i)
            LOGGER.info(f"Remote guest_pubkey_ids to Host {i}")

    # ------------------------------------------------------------------
    # per-host function factories (bind values eagerly, see class NOTE)
    # ------------------------------------------------------------------
    def _make_sign_func(self, idx):
        """Return ``(k, v) -> (k, sign_id(k, ...))`` bound to the idx-th guest private key."""
        rsa_d, rsa_n = self.d[idx], self.n[idx]
        rsa_p, rsa_q = self.p[idx], self.q[idx]
        cp, cq = self.cp[idx], self.cq[idx]
        sign_id = self.sign_id
        return lambda k, v: (k, sign_id(k, rsa_d, rsa_n, rsa_p, rsa_q, cp, cq))

    @staticmethod
    def _make_unblind_func(rsa_n, hash_operator, salt):
        """Return a join function ``(g, r) -> (g[0], hash(r / g[1] mod rsa_n))`` for one host key."""
        def _unblind(g, r):
            # gmpy2.divm(a, b, m) solves b * x = a (mod m): divides the
            # signature r by the blinding factor g[1]
            return g[0], RsaIntersectionGuest.hash(gmpy2.divm(int(r), int(g[1]), rsa_n),
                                                   hash_operator,
                                                   salt)
        return _unblind

    @staticmethod
    def _make_filter_check_func(host_filter):
        """Return ``(k, v) -> host_filter.check(v[1])`` for one host's filter."""
        return lambda k, v: host_filter.check(v[1])

    @staticmethod
    def _make_select_encrypt_id_func(idx):
        """Return ``(k, v) -> (v[idx], None)``: re-key by the idx-th host's encrypted id."""
        return lambda k, v: (v[idx], None)

    # ------------------------------------------------------------------
    # shared protocol steps
    # ------------------------------------------------------------------
    def _pubkey_ids_process(self, data, **hash_kwargs):
        """Blind ``data`` once per received host public key.

        Extra keyword arguments (e.g. ``hash_operator``/``salt``) are forwarded
        unchanged to ``pubkey_id_process``.

        Returns:
            list with one processed table per entry of ``rcv_e``/``rcv_n``.
        """
        return [self.pubkey_id_process(data,
                                       fraction=self.random_base_fraction,
                                       random_bit=self.random_bit,
                                       rsa_e=self.rcv_e[i],
                                       rsa_n=self.rcv_n[i],
                                       **hash_kwargs)
                for i in range(len(self.rcv_e))]

    def _process_and_send_pubkey_ids(self, data_instances):
        """Blind ``data_instances`` with ``first_hash_operator``/``salt`` for every host and send the keys.

        Returns:
            the blinded tables, one per host (kept locally for unblinding).
        """
        pubkey_ids_process_list = self._pubkey_ids_process(data_instances,
                                                           hash_operator=self.first_hash_operator,
                                                           salt=self.salt)
        LOGGER.info("Finish pubkey_ids_process")
        self._remote_guest_pubkey_ids(pubkey_ids_process_list)
        return pubkey_ids_process_list

    def _unblind_host_sign_guest_ids(self, pubkey_ids_process_list):
        """Receive host-signed guest ids and remove the blinding factor.

        For host i, each value ``g`` of ``pubkey_ids_process_list[i]`` is joined
        with the host's signature ``r`` for the same key; ``g[1]`` is the random
        blinding factor and ``g[0]`` the guest id that is carried through.

        Returns:
            one table per host: table(r^e % n * hash(sid), (g[0], hash(signature / g[1]))).
        """
        recv_host_sign_guest_ids_list = self.transfer_variable.host_sign_guest_ids.get(idx=-1)
        LOGGER.info("Get host_sign_guest_ids from Host")
        hash_operator = self.final_hash_operator
        salt = self.rsa_params.salt
        return [guest_ids.join(recv_host_sign_guest_ids_list[i],
                               self._make_unblind_func(self.rcv_n[i], hash_operator, salt))
                for i, guest_ids in enumerate(pubkey_ids_process_list)]

    @staticmethod
    def _match_host_prvkey_ids(host_sign_guest_ids_list, host_prvkey_ids_list):
        """Re-key unblinded ids by their hash and join them with the host's processed ids.

        Returns:
            one table per host: table(hash(guest_ids_process/r), sid) restricted
            to hashes also present in that host's ``host_prvkey_ids``.
        """
        # table(hash(guest_ids_process/r), sid)
        sid_host_sign_guest_ids_list = [g.map(lambda k, v: (v[1], v[0])) for g in host_sign_guest_ids_list]
        return [v.join(host_prvkey_ids_list[i], lambda sid, h: sid)
                for i, v in enumerate(sid_host_sign_guest_ids_list)]

    def _maybe_send_intersect_ids(self, encrypt_intersect_ids_list, intersect_ids):
        """Send the intersection to the hosts when ``sync_intersect_ids`` is enabled."""
        if self.sync_intersect_ids:
            self.send_intersect_ids(encrypt_intersect_ids_list, intersect_ids)
        else:
            LOGGER.info("Skip sync intersect ids with Host(s).")

    # ------------------------------------------------------------------
    # public protocol entry points
    # ------------------------------------------------------------------
    def sign_host_ids(self, host_pubkey_ids_list):
        """Sign each host's ids with the matching guest private key.

        Args:
            host_pubkey_ids_list: one table per host; table i is signed with the
                i-th guest key (``d[i]``, ``n[i]``, ``p[i]``, ``q[i]``, ``cp[i]``, ``cq[i]``).

        Returns:
            list of tables mapping each key to ``sign_id`` of that key.
        """
        guest_sign_host_ids_list = [host_pubkey_ids.map(self._make_sign_func(i))
                                    for i, host_pubkey_ids in enumerate(host_pubkey_ids_list)]
        LOGGER.info("Sign host_pubkey_ids with guest prv_keys")
        return guest_sign_host_ids_list

    def send_intersect_ids(self, encrypt_intersect_ids_list, intersect_ids):
        """Send each host the intersection keyed by that host's encrypted ids.

        With several hosts, the values of ``intersect_ids`` are indexed by host
        position and host i receives the i-th entry as key.  With a single host
        the keys of ``encrypt_intersect_ids_list[0]`` are sent directly.
        """
        if len(self.host_party_id_list) > 1:
            for i, host_party_id in enumerate(self.host_party_id_list):
                remote_intersect_id = intersect_ids.map(self._make_select_encrypt_id_func(i))
                self.transfer_variable.intersect_ids.remote(remote_intersect_id,
                                                            role=consts.HOST,
                                                            idx=i)
                LOGGER.info(f"Remote intersect ids to Host {host_party_id}!")
        else:
            remote_intersect_id = encrypt_intersect_ids_list[0].mapValues(lambda v: None)
            self.transfer_variable.intersect_ids.remote(remote_intersect_id,
                                                        role=consts.HOST,
                                                        idx=0)
            LOGGER.info("Remote intersect ids to Host!")

    def get_host_intersect_ids(self, guest_prvkey_ids_list):
        """Receive the host-computed intersections and map them back to guest ids.

        Args:
            guest_prvkey_ids_list: per-host pair tables from
                ``cal_prvkey_ids_process_pair`` used to decode host i's result.

        Returns:
            the per-host results combined by ``filter_intersect_ids`` (encrypted ids kept).
        """
        encrypt_intersect_ids_list = self.transfer_variable.host_intersect_ids.get(idx=-1)
        LOGGER.info("Get intersect ids from Host")
        intersect_ids_pair_list = [self.extract_intersect_ids(ids, guest_prvkey_ids_list[i])
                                   for i, ids in enumerate(encrypt_intersect_ids_list)]
        intersect_ids = self.filter_intersect_ids(intersect_ids_pair_list, keep_encrypt_ids=True)
        return intersect_ids

    def split_calculation_process(self, data_instances):
        """Run the split (key-parity partitioned) RSA intersection.

        Odd keys are blinded with the host public keys and matched on the guest
        side; even keys are processed with freshly generated guest keys and the
        host returns their intersection.  Keys are tested with ``k & 1``, so they
        are expected to be integers (presumably already-hashed ids — confirm).

        Returns:
            union of the odd and even intersection tables.
        """
        LOGGER.info("RSA intersect using split calculation.")
        # split data by key parity
        sid_hash_odd = data_instances.filter(lambda k, v: k & 1)
        sid_hash_even = data_instances.filter(lambda k, v: not k & 1)

        # generate guest keys (used for the even ids)
        self.generate_protocol_key()
        LOGGER.info("Generate guest protocol key!")
        # send public key e & n to all host
        for i, host_party_id in enumerate(self.host_party_id_list):
            guest_public_key = {"e": self.e[i], "n": self.n[i]}
            self.transfer_variable.guest_pubkey.remote(guest_public_key,
                                                       role=consts.HOST,
                                                       idx=i)
            LOGGER.info(f"Remote public key to Host {host_party_id}.")

        # receive host pub keys (used for the odd ids)
        self._recv_host_pubkeys("Get host_public_key from Host")

        # blind own odd ids with pub keys from host and send the blinded keys
        pubkey_ids_process_list = self._pubkey_ids_process(sid_hash_odd)
        LOGGER.info("Perform pubkey_ids_process")
        self._remote_guest_pubkey_ids(pubkey_ids_process_list)

        # process & send guest even ids with the guest private keys
        prvkey_ids_process_pair_list = []
        for i, host_party_id in enumerate(self.host_party_id_list):
            prvkey_ids_process_pair = self.cal_prvkey_ids_process_pair(sid_hash_even,
                                                                        self.d[i],
                                                                        self.n[i],
                                                                        self.p[i],
                                                                        self.q[i],
                                                                        self.cp[i],
                                                                        self.cq[i])
            prvkey_ids_process = prvkey_ids_process_pair.mapValues(lambda v: None)
            self.transfer_variable.guest_prvkey_ids.remote(prvkey_ids_process,
                                                           role=consts.HOST,
                                                           idx=i)
            prvkey_ids_process_pair_list.append(prvkey_ids_process_pair)
            LOGGER.info(f"Remote guest_prvkey_ids to host {host_party_id}")

        # get, sign & send back host even ids
        host_pubkey_ids_list = self.get_host_pubkey_ids()
        guest_sign_host_ids_list = self.sign_host_ids(host_pubkey_ids_list)
        for i, host_party_id in enumerate(self.host_party_id_list):
            self.transfer_variable.guest_sign_host_ids.remote(guest_sign_host_ids_list[i],
                                                              role=consts.HOST,
                                                              idx=i)
            LOGGER.info(f"Remote guest_sign_host_ids to Host {host_party_id}.")

        # get prvkey processed odd ids from host
        host_prvkey_ids_list = self.get_host_prvkey_ids()
        # receive host-signed odd ids and remove the blinding factor
        host_sign_guest_ids_list = self._unblind_host_sign_guest_ids(pubkey_ids_process_list)
        # odd intersection, computed on the guest side:
        # table(hash(guest_ids_process/r), sid)
        encrypt_intersect_odd_ids_list = self._match_host_prvkey_ids(host_sign_guest_ids_list,
                                                                     host_prvkey_ids_list)
        intersect_odd_ids = self.filter_intersect_ids(encrypt_intersect_odd_ids_list, keep_encrypt_ids=True)
        # even intersection, reported by the host(s)
        intersect_even_ids = self.get_host_intersect_ids(prvkey_ids_process_pair_list)
        intersect_ids = intersect_odd_ids.union(intersect_even_ids)

        # only the odd part is sent: the even part was computed by the host(s)
        self._maybe_send_intersect_ids(encrypt_intersect_odd_ids_list, intersect_odd_ids)
        return intersect_ids

    def unified_calculation_process(self, data_instances):
        """Run the unified RSA intersection over all of ``data_instances``.

        Returns:
            the intersection as combined by ``filter_intersect_ids`` (encrypted ids kept).
        """
        LOGGER.info("RSA intersect using unified calculation.")
        self._recv_host_pubkeys("Get RSA host_public_key from Host")
        pubkey_ids_process_list = self._process_and_send_pubkey_ids(data_instances)

        host_prvkey_ids_list = self.get_host_prvkey_ids()
        LOGGER.info("Get host_prvkey_ids")

        host_sign_guest_ids_list = self._unblind_host_sign_guest_ids(pubkey_ids_process_list)
        # intersect table(hash(guest_ids_process/r), sid)
        encrypt_intersect_ids_list = self._match_host_prvkey_ids(host_sign_guest_ids_list,
                                                                 host_prvkey_ids_list)
        intersect_ids = self.filter_intersect_ids(encrypt_intersect_ids_list, keep_encrypt_ids=True)
        self._maybe_send_intersect_ids(encrypt_intersect_ids_list, intersect_ids)
        return intersect_ids

    def get_intersect_key(self, party_id):
        """Return the stored public key of host ``party_id`` as strings, for caching.

        Raises:
            ValueError: if ``party_id`` is not in ``host_party_id_list``.
        """
        idx = self.host_party_id_list.index(party_id)
        intersect_key = {"rcv_n": str(self.rcv_n[idx]),
                         "rcv_e": str(self.rcv_e[idx])}
        return intersect_key

    def load_intersect_key(self, cache_meta):
        """Restore ``rcv_e``/``rcv_n`` from cache metadata keyed by ``str(host_party_id)``."""
        self.rcv_e, self.rcv_n = [], []
        for host_party in self.host_party_id_list:
            intersect_key = cache_meta[str(host_party)]["intersect_key"]
            self.rcv_e.append(int(intersect_key["rcv_e"]))
            self.rcv_n.append(int(intersect_key["rcv_n"]))

    def run_cardinality(self, data_instances):
        """Compute only the intersection size, checking ids against each host's filter.

        Sets ``self.intersect_num`` and, when ``sync_cardinality`` is enabled,
        sends it to all hosts.
        """
        LOGGER.info("run cardinality_only with RSA")
        self._recv_host_pubkeys("Get RSA host_public_key from Host")
        pubkey_ids_process_list = self._process_and_send_pubkey_ids(data_instances)

        host_filter_list = self.get_host_filter()
        LOGGER.info("Get host_filter_list")

        host_sign_guest_ids_list = self._unblind_host_sign_guest_ids(pubkey_ids_process_list)
        # keep ids whose unblinded hash passes the matching host's filter
        intersect_ids_list = [host_sign_guest_ids_list[i].filter(self._make_filter_check_func(host_filter_list[i]))
                              for i in range(len(self.host_party_id_list))]
        intersect_ids_list = [ids.map(lambda k, v: (v[0], None)) for ids in intersect_ids_list]
        intersect_ids = self.get_common_intersection(intersect_ids_list)
        self.intersect_num = intersect_ids.count()

        if self.sync_cardinality:
            self.transfer_variable.cardinality.remote(self.intersect_num, role=consts.HOST, idx=-1)
            LOGGER.info("Sent intersect cardinality to host.")
        else:
            LOGGER.info("Skip sync intersect cardinality with host(s)")

    def generate_cache(self, data_instances):
        """Collect the host public keys and private-key-processed ids for caching.

        ``data_instances`` is not used here.

        Returns:
            (cache_data, cache_meta): both dicts keyed by host party id;
            ``cache_data`` holds the host prvkey ids table, ``cache_meta`` the
            cache id, intersect method meta and host public key.
        """
        LOGGER.info("Run RSA intersect cache")
        self._recv_host_pubkeys("Get RSA host_public_key from Host")

        cache_id_list = self.cache_transfer_variable.get(idx=-1)
        LOGGER.info("Get cache_id from all host")

        host_prvkey_ids_list = self.get_host_prvkey_ids()
        LOGGER.info("Get host_prvkey_ids")

        cache_data, cache_meta = {}, {}
        intersect_meta = self.get_intersect_method_meta()
        for i, party_id in enumerate(self.host_party_id_list):
            cache_meta[party_id] = {"cache_id": cache_id_list[i],
                                    "intersect_meta": intersect_meta,
                                    "intersect_key": self.get_intersect_key(party_id)}
            cache_data[party_id] = host_prvkey_ids_list[i]
        return cache_data, cache_meta

    def cache_unified_calculation_process(self, data_instances, cache_data):
        """Run the unified intersection using cached host prvkey ids.

        Expects ``rcv_e``/``rcv_n`` to be set beforehand (e.g. by ``load_intersect_key``).

        Returns:
            the intersection as combined by ``filter_intersect_ids`` (encrypted ids kept).
        """
        LOGGER.info("RSA intersect using cache.")
        pubkey_ids_process_list = self._process_and_send_pubkey_ids(data_instances)

        host_sign_guest_ids_list = self._unblind_host_sign_guest_ids(pubkey_ids_process_list)
        host_prvkey_ids_list = self.extract_cache_list(cache_data, self.host_party_id_list)
        # intersect table(hash(guest_ids_process/r), sid)
        encrypt_intersect_ids_list = self._match_host_prvkey_ids(host_sign_guest_ids_list,
                                                                 host_prvkey_ids_list)
        intersect_ids = self.filter_intersect_ids(encrypt_intersect_ids_list, keep_encrypt_ids=True)
        self._maybe_send_intersect_ids(encrypt_intersect_ids_list, intersect_ids)
        return intersect_ids
|