|
- from fate_arch.session import get_parties
- from federatedml.transfer_variable.base_transfer_variable import Variable, BaseTransferVariables
- from federatedml.util import consts
- from federatedml.secureprotol.diffie_hellman import DiffieHellman
- from federatedml.secureprotol import PaillierEncrypt
- from federatedml.secureprotol.fate_paillier import PaillierPublicKey
- from federatedml.secureprotol.encrypt import PadsCipher
- from federatedml.util import LOGGER
- from typing import Union
- import hashlib
- """
- Base Transfer variable
- """
class HomoTransferBase(BaseTransferVariables):
    """Base for homo-federation transfer-variable bundles.

    Stores the server role(s) and client role(s) and creates ``Variable``
    objects whose names are namespaced as
    ``"<module>.<ClassName>.[<prefix>_]<name>"``.

    :param server: role name(s) acting as server.
    :param clients: role names acting as clients.
    :param prefix: optional extra namespace component placed before every
        variable name created by this bundle.
    """

    def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
        super().__init__()
        cls = self.__class__
        namespace = f"{cls.__module__}.{cls.__name__}."
        self.prefix = namespace if prefix is None else f"{namespace}{prefix}_"
        self.server = server
        self.clients = clients

    def create_client_to_server_variable(self, name):
        """Get or create the namespaced variable ``name`` flowing clients -> server."""
        qualified = f"{self.prefix}{name}"

        def build():
            return Variable(qualified, self.clients, self.server)

        # build() is only a factory; presumably get_or_create calls it only
        # when no variable is registered under `qualified` yet -- confirm.
        return Variable.get_or_create(qualified, build)

    def create_server_to_client_variable(self, name):
        """Get or create the namespaced variable ``name`` flowing server -> clients."""
        qualified = f"{self.prefix}{name}"

        def build():
            return Variable(qualified, self.server, self.clients)

        return Variable.get_or_create(qualified, build)

    @staticmethod
    def get_parties(roles):
        """Map role names to parties via ``fate_arch.session.get_parties()``.

        The bare ``get_parties`` call below resolves to the module-level
        import, not to this static method: function bodies do not see
        class-scope names.
        """
        session_parties = get_parties()
        return session_parties.roles_to_parties(roles=roles)

    @property
    def client_parties(self):
        """Parties playing any of the client roles."""
        return self.get_parties(roles=self.clients)

    @property
    def server_parties(self):
        """Parties playing the server role(s)."""
        return self.get_parties(roles=self.server)
- """
- Client & Server Communication
- """
class CommunicatorTransVar(HomoTransferBase):
    """Generic two-way channel between the client roles and the server role(s).

    ``client_to_server`` carries objects uploaded by clients;
    ``server_to_client`` carries objects sent back by the server.
    """

    def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
        super().__init__(server=server, clients=clients, prefix=prefix)
        uplink = self.create_client_to_server_variable(name="client_to_server")
        downlink = self.create_server_to_client_variable(name="server_to_client")
        self.client_to_server = uplink
        self.server_to_client = downlink
class ServerCommunicator(object):
    """Server-side endpoint of a ``CommunicatorTransVar`` channel.

    Clients are addressed by their position in the client-party list.
    """

    def __init__(self, prefix=None):
        self.trans_var = CommunicatorTransVar(prefix=prefix)
        self._client_parties = self.trans_var.client_parties

    def get_parties(self, party_idx):
        """Resolve ``party_idx`` to client parties.

        :param party_idx: ``-1`` for all clients; an ``int`` position for a
            single client; or a list/tuple of positions for several clients.
            Duplicate positions are dropped (first occurrence kept), and the
            result follows the order in which the positions were given.
        :return: the full client-party list, a single party, or a list of
            parties, respectively.
        :raises ValueError: if ``party_idx`` is none of the above.
        """
        if party_idx == -1:
            return self._client_parties
        if isinstance(party_idx, (list, tuple)):
            # dict.fromkeys de-duplicates while keeping the caller's order; a
            # set would reorder positions, so results could no longer be
            # matched back to the requested indices.
            return [self._client_parties[i] for i in dict.fromkeys(party_idx)]
        if isinstance(party_idx, int):
            return self._client_parties[party_idx]
        raise ValueError('illegal party idx {}'.format(party_idx))

    def get_obj(self, suffix=tuple(), party_idx=-1):
        """Receive what the selected client(s) sent on ``client_to_server``.

        :param suffix: passed through to the transfer variable unchanged.
        :param party_idx: client selector, see :meth:`get_parties`.
        """
        party = self.get_parties(party_idx)
        return self.trans_var.client_to_server.get_parties(parties=party, suffix=suffix)

    def broadcast_obj(self, obj, suffix=tuple(), party_idx=-1):
        """Send ``obj`` to the selected client(s) on ``server_to_client``.

        :param obj: object to send.
        :param suffix: passed through to the transfer variable unchanged.
        :param party_idx: client selector, see :meth:`get_parties`.
        """
        party = self.get_parties(party_idx)
        self.trans_var.server_to_client.remote_parties(obj=obj, parties=party, suffix=suffix)
class ClientCommunicator(object):
    """Client-side endpoint of a ``CommunicatorTransVar`` channel."""

    def __init__(self, prefix=None):
        self.trans_var = CommunicatorTransVar(prefix=prefix)
        self._server_parties = self.trans_var.server_parties

    def send_obj(self, obj, suffix=tuple()):
        """Send ``obj`` to the server party/parties on ``client_to_server``."""
        uplink = self.trans_var.client_to_server
        uplink.remote_parties(obj=obj, parties=self._server_parties, suffix=suffix)

    def get_obj(self, suffix=tuple()):
        """Receive what the server sent on ``server_to_client``.

        Returns whatever ``Variable.get_parties`` yields for the server
        parties (presumably one entry per server party -- confirm).
        """
        downlink = self.trans_var.server_to_client
        return downlink.get_parties(parties=self._server_parties, suffix=suffix)
- """
- Diffie Hellman Exchange
- """
class DHTransVar(HomoTransferBase):
    """Transfer variables for a Diffie-Hellman exchange relayed by the server.

    * ``pubkey`` (server -> clients): the ``(p, g)`` parameters.
    * ``p_power_r`` (clients -> server): each client's ``(uuid, value)`` pair.
    * ``p_power_r_bc`` (server -> clients): the collected uuid -> value map.
    """

    def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
        super().__init__(server=server, clients=clients, prefix=prefix)
        upload = self.create_client_to_server_variable
        download = self.create_server_to_client_variable
        self.p_power_r = upload(name="p_power_r")
        self.p_power_r_bc = download(name="p_power_r_bc")
        self.pubkey = download(name="pubkey")
class DHServer(object):
    """Server (relay) side of the Diffie-Hellman exchange."""

    def __init__(self, trans_var: DHTransVar = None):
        trans_var = trans_var if trans_var is not None else DHTransVar()
        self._p_power_r = trans_var.p_power_r
        self._p_power_r_bc = trans_var.p_power_r_bc
        self._pubkey = trans_var.pubkey
        self._client_parties = trans_var.client_parties

    def key_exchange(self):
        """Run the server side of one exchange round.

        1. Generate ``(p, g)`` and send it, as ints, to every client.
        2. Collect the clients' uploads into a dict (presumably one
           ``(uuid, value)`` pair per client).
        3. Broadcast that dict back to every client.
        """
        p, g = DiffieHellman.key_pair()
        clients = self._client_parties
        self._pubkey.remote_parties(obj=(int(p), int(g)), parties=clients)
        uploads = self._p_power_r.get_parties(parties=clients)
        uuid_to_value = dict(uploads)
        self._p_power_r_bc.remote_parties(obj=uuid_to_value, parties=clients)
class DHClient(object):
    """Client side of the Diffie-Hellman exchange relayed by the server."""

    def __init__(self, trans_var: DHTransVar = None):
        trans_var = DHTransVar() if trans_var is None else trans_var
        self._p_power_r = trans_var.p_power_r
        self._p_power_r_bc = trans_var.p_power_r_bc
        self._pubkey = trans_var.pubkey
        self._server_parties = trans_var.server_parties

    def key_exchange(self, uuid: str):
        """Derive a secret with every other client.

        :param uuid: this client's id; it tags the value uploaded to the
            server and is excluded from the result.
        :return: dict mapping each other client's uuid to
            ``DiffieHellman.decrypt(peer_value, r, p)``, where ``r`` is this
            client's fresh secret.
        """
        servers = self._server_parties
        # [0]: take the response of the first (presumably only) server party.
        p, g = self._pubkey.get_parties(parties=servers)[0]
        secret = DiffieHellman.generate_secret(p)
        own_value = DiffieHellman.encrypt(g, secret, p)
        self._p_power_r.remote_parties(obj=(uuid, own_value), parties=servers)
        peer_values = self._p_power_r_bc.get_parties(parties=servers)[0]
        shared = {}
        for peer_uuid, peer_value in peer_values.items():
            if peer_uuid == uuid:
                continue
            shared[peer_uuid] = DiffieHellman.decrypt(peer_value, secret, p)
        return shared
- """
- UUID
- """
class UUIDTransVar(HomoTransferBase):
    """Single server -> clients variable carrying each client's assigned id."""

    def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
        super().__init__(server=server, clients=clients, prefix=prefix)
        uuid_variable = self.create_server_to_client_variable(name="uuid")
        self.uuid = uuid_variable
class UUIDServer(object):
    """Issues a distinct id to each client party."""

    def __init__(self, trans_var: UUIDTransVar = None):
        if trans_var is None:
            trans_var = UUIDTransVar()
        self._uuid_transfer = trans_var.uuid
        self._uuid_set = set()  # every id issued so far
        self._ind = -1  # last counter value consumed by _next_uuid
        self.client_parties = trans_var.client_parties

    # noinspection PyUnusedLocal
    @staticmethod
    def generate_id(ind, *args, **kwargs):
        """Return the hex MD5 digest of the text of ``ind``.

        Extra positional/keyword arguments are accepted and ignored.
        """
        text = f"{ind}"
        return hashlib.md5(text.encode("ascii")).hexdigest()

    def _next_uuid(self):
        """Advance the counter until an unissued id appears; record and return it."""
        self._ind += 1
        uid = self.generate_id(self._ind)
        while uid in self._uuid_set:
            self._ind += 1
            uid = self.generate_id(self._ind)
        self._uuid_set.add(uid)
        return uid

    def validate_uuid(self):
        """Send one freshly issued id to each client party, one party at a time."""
        for client in self.client_parties:
            fresh = self._next_uuid()
            self._uuid_transfer.remote_parties(obj=fresh, parties=[client])
class UUIDClient(object):
    """Receives this client's id from the server."""

    def __init__(self, trans_var: UUIDTransVar = None):
        trans_var = UUIDTransVar() if trans_var is None else trans_var
        self._uuid_variable = trans_var.uuid
        self._server_parties = trans_var.server_parties

    def generate_uuid(self):
        """Return the id sent by the first server party.

        Despite the name, nothing is generated locally; the id is received.
        """
        received = self._uuid_variable.get_parties(parties=self._server_parties)
        return received[0]
- """
- Random Padding
- """
class RandomPaddingCipherTransVar(HomoTransferBase):
    """Bundles the id and Diffie-Hellman channels used to set up padding ciphers.

    Both nested bundles receive this object's fully qualified ``prefix``, so
    their variable names are namespaced beneath it.
    """

    def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
        super().__init__(server=server, clients=clients, prefix=prefix)
        shared_kwargs = dict(server=server, clients=clients, prefix=self.prefix)
        self.uuid_transfer_variable = UUIDTransVar(**shared_kwargs)
        self.dh_transfer_variable = DHTransVar(**shared_kwargs)
class RandomPaddingCipherServer(object):
    """Server side of padding-cipher setup: id assignment, then the DH relay."""

    def __init__(self, trans_var: RandomPaddingCipherTransVar = None):
        if trans_var is None:
            trans_var = RandomPaddingCipherTransVar()
        self._uuid = UUIDServer(trans_var=trans_var.uuid_transfer_variable)
        self._dh = DHServer(trans_var=trans_var.dh_transfer_variable)

    def exchange_secret_keys(self):
        """Issue ids to all clients, then relay one Diffie-Hellman round."""
        steps = (
            ("synchronizing uuid", self._uuid.validate_uuid),
            ("Diffie-Hellman keys exchanging", self._dh.key_exchange),
        )
        for message, step in steps:
            LOGGER.info(message)
            step()
class RandomPaddingCipherClient(object):
    """Client side of padding-cipher setup.

    :meth:`create_cipher` must be called (taking part in the server-driven
    id and Diffie-Hellman rounds) before :meth:`encrypt` can be used.
    """

    def __init__(self, trans_var: RandomPaddingCipherTransVar = None):
        if trans_var is None:
            trans_var = RandomPaddingCipherTransVar()
        self._uuid = UUIDClient(trans_var=trans_var.uuid_transfer_variable)
        self._dh = DHClient(trans_var=trans_var.dh_transfer_variable)
        self._cipher = None  # set by create_cipher()

    def create_cipher(self) -> PadsCipher:
        """Receive this client's id, run DH with its peers and build a PadsCipher.

        :return: the new cipher; it is also stored for later :meth:`encrypt`
            calls.
        """
        LOGGER.info("synchronizing uuid")
        uuid = self._uuid.generate_uuid()
        LOGGER.info("got local uuid")
        LOGGER.info("Diffie-Hellman keys exchanging")
        exchanged_keys = self._dh.key_exchange(uuid)
        LOGGER.info("got Diffie-Hellman exchanged keys")
        cipher = PadsCipher()
        cipher.set_self_uuid(uuid)
        cipher.set_exchanged_keys(exchanged_keys)
        self._cipher = cipher
        return cipher

    def encrypt(self, transfer_weights):
        """Encrypt ``transfer_weights`` with the cipher built by create_cipher().

        :raises RuntimeError: if create_cipher() has not been called yet.
        """
        if self._cipher is None:
            # Previously this surfaced as an opaque AttributeError on None.
            raise RuntimeError("create_cipher() must be called before encrypt()")
        return self._cipher.encrypt(transfer_weights)
|