blocks.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from fate_arch.session import get_parties
  2. from federatedml.transfer_variable.base_transfer_variable import Variable, BaseTransferVariables
  3. from federatedml.util import consts
  4. from federatedml.secureprotol.diffie_hellman import DiffieHellman
  5. from federatedml.secureprotol import PaillierEncrypt
  6. from federatedml.secureprotol.fate_paillier import PaillierPublicKey
  7. from federatedml.secureprotol.encrypt import PadsCipher
  8. from federatedml.util import LOGGER
  9. from typing import Union
  10. import hashlib
  11. """
  12. Base Transfer variable
  13. """
  class HomoTransferBase(BaseTransferVariables):
      """Factory for named transfer variables between server roles and client roles.

      Every variable name is namespaced by this class's module and class name,
      optionally followed by ``prefix`` and an underscore (presumably so that
      different instances/subclasses do not share variables).
      """

      def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
          super().__init__()
          namespace = f"{self.__class__.__module__}.{self.__class__.__name__}."
          # An explicit prefix is appended to the class namespace, joined by "_".
          self.prefix = namespace if prefix is None else f"{namespace}{prefix}_"
          self.server = server
          self.clients = clients

      def create_client_to_server_variable(self, name):
          """Return the clients -> server variable ``name`` via ``Variable.get_or_create``."""
          qualified = f"{self.prefix}{name}"
          return Variable.get_or_create(
              qualified, lambda: Variable(qualified, self.clients, self.server)
          )

      def create_server_to_client_variable(self, name):
          """Return the server -> clients variable ``name`` via ``Variable.get_or_create``."""
          qualified = f"{self.prefix}{name}"
          return Variable.get_or_create(
              qualified, lambda: Variable(qualified, self.server, self.clients)
          )

      @staticmethod
      def get_parties(roles):
          """Map ``roles`` to the concrete parties of the current session."""
          # Inside this body ``get_parties`` is the module-level import from
          # fate_arch.session, not this staticmethod: class attributes are not
          # in scope within method bodies.
          session_parties = get_parties()
          return session_parties.roles_to_parties(roles=roles)

      @property
      def client_parties(self):
          """Parties playing the client roles (re-resolved on every access)."""
          return self.get_parties(roles=self.clients)

      @property
      def server_parties(self):
          """Parties playing the server roles (re-resolved on every access)."""
          return self.get_parties(roles=self.server)
  38. """
  39. Client & Server Communication
  40. """
  class CommunicatorTransVar(HomoTransferBase):
      """Two generic channels: ``client_to_server`` and ``server_to_client``."""

      def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
          super().__init__(server=server, clients=clients, prefix=prefix)
          # Channel names are qualified by the prefix computed in HomoTransferBase.
          upstream = self.create_client_to_server_variable(name="client_to_server")
          downstream = self.create_server_to_client_variable(name="server_to_client")
          self.client_to_server = upstream
          self.server_to_client = downstream
  class ServerCommunicator(object):
      """Server-side endpoint of a ``CommunicatorTransVar`` channel pair."""

      def __init__(self, prefix=None):
          self.trans_var = CommunicatorTransVar(prefix=prefix)
          # Client parties are resolved once, at construction time.
          self._client_parties = self.trans_var.client_parties

      def get_parties(self, party_idx):
          """Select client parties by index.

          :param party_idx: ``-1`` for all clients (note: it does NOT mean "last
              client"), an ``int`` for one client, or a list/tuple of ints for
              several clients.
          :return: the whole client-party collection for ``-1``; a single party
              for an ``int``; for a list/tuple, a list of parties in the order the
              indices were first given, with duplicate indices dropped.
          :raises ValueError: for any other kind of ``party_idx``.
          """
          if party_idx == -1:
              return self._client_parties
          if isinstance(party_idx, (list, tuple)):
              # dict.fromkeys removes duplicates but keeps the caller's order, so
              # the parties (and hence get_obj results) line up with the requested
              # indices; the previous set() yielded them in hash order.
              return [self._client_parties[i] for i in dict.fromkeys(party_idx)]
          if isinstance(party_idx, int):
              return self._client_parties[party_idx]
          raise ValueError('illegal party idx {}'.format(party_idx))

      def get_obj(self, suffix=tuple(), party_idx=-1):
          """Receive from the selected client(s) on ``client_to_server`` under ``suffix``."""
          party = self.get_parties(party_idx)
          return self.trans_var.client_to_server.get_parties(parties=party, suffix=suffix)

      def broadcast_obj(self, obj, suffix=tuple(), party_idx=-1):
          """Send ``obj`` to the selected client(s) on ``server_to_client`` under ``suffix``."""
          party = self.get_parties(party_idx)
          self.trans_var.server_to_client.remote_parties(obj=obj, parties=party, suffix=suffix)
  class ClientCommunicator(object):
      """Client-side endpoint of a ``CommunicatorTransVar`` channel pair."""

      def __init__(self, prefix=None):
          self.trans_var = CommunicatorTransVar(prefix=prefix)
          # Server parties are resolved once, at construction time.
          self._server_parties = self.trans_var.server_parties

      def send_obj(self, obj, suffix=tuple()):
          """Send ``obj`` to the server party/parties on ``client_to_server``."""
          channel = self.trans_var.client_to_server
          channel.remote_parties(obj=obj, parties=self._server_parties, suffix=suffix)

      def get_obj(self, suffix=tuple()):
          """Return whatever ``server_to_client.get_parties`` yields for the server parties."""
          channel = self.trans_var.server_to_client
          return channel.get_parties(parties=self._server_parties, suffix=suffix)
  74. """
  75. Diffie Hellman Exchange
  76. """
  class DHTransVar(HomoTransferBase):
      """Channels for the server-relayed Diffie-Hellman exchange.

      * ``pubkey``        server -> clients: the ``(p, g)`` parameters.
      * ``p_power_r``     clients -> server: each client's ``(uuid, share)``.
      * ``p_power_r_bc``  server -> clients: all collected shares.
      """

      def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
          super().__init__(server=server, clients=clients, prefix=prefix)
          make_up = self.create_client_to_server_variable
          make_down = self.create_server_to_client_variable
          self.p_power_r = make_up(name="p_power_r")
          self.p_power_r_bc = make_down(name="p_power_r_bc")
          self.pubkey = make_down(name="pubkey")
  class DHServer(object):
      """Server side of the Diffie-Hellman exchange: publishes (p, g) and relays client shares."""

      def __init__(self, trans_var: DHTransVar = None):
          trans_var = DHTransVar() if trans_var is None else trans_var
          self._p_power_r = trans_var.p_power_r
          self._p_power_r_bc = trans_var.p_power_r_bc
          self._pubkey = trans_var.pubkey
          self._client_parties = trans_var.client_parties

      def key_exchange(self):
          """Run the server's part of one exchange.

          1. Generate ``(p, g)`` with ``DiffieHellman.key_pair`` and send it to every client.
          2. Collect one ``(uuid, share)`` pair from each client.
          3. Broadcast the resulting ``{uuid: share}`` dict to every client.
          """
          p, g = DiffieHellman.key_pair()
          # Cast to plain ints before sending (presumably for serialization).
          public_params = (int(p), int(g))
          self._pubkey.remote_parties(obj=public_params, parties=self._client_parties)
          # DHClient sends (uuid, share) tuples, so dict() builds a uuid -> share map.
          received = self._p_power_r.get_parties(parties=self._client_parties)
          shares_by_uuid = dict(received)
          self._p_power_r_bc.remote_parties(obj=shares_by_uuid, parties=self._client_parties)
  class DHClient(object):
      """Client side of the server-relayed Diffie-Hellman exchange."""

      def __init__(self, trans_var: DHTransVar = None):
          trans_var = DHTransVar() if trans_var is None else trans_var
          self._p_power_r = trans_var.p_power_r
          self._p_power_r_bc = trans_var.p_power_r_bc
          self._pubkey = trans_var.pubkey
          self._server_parties = trans_var.server_parties

      def key_exchange(self, uuid: str):
          """Exchange shares through the server and derive one key per peer.

          :param uuid: this client's id; it is sent along with its share and
              excluded from the result.
          :return: ``{peer_uuid: DiffieHellman.decrypt(peer_share, r, p)}`` for
              every peer in the server's broadcast except ``uuid``, where ``r`` is
              this client's fresh secret (presumably the pairwise shared secret).
          """
          p, g = self._pubkey.get_parties(parties=self._server_parties)[0]
          secret = DiffieHellman.generate_secret(p)
          own_share = DiffieHellman.encrypt(g, secret, p)
          self._p_power_r.remote_parties(obj=(uuid, own_share), parties=self._server_parties)
          all_shares = self._p_power_r_bc.get_parties(parties=self._server_parties)[0]
          share_secret = {}
          for peer_uid, peer_share in all_shares.items():
              if peer_uid == uuid:
                  continue  # skip our own share
              share_secret[peer_uid] = DiffieHellman.decrypt(peer_share, secret, p)
          return share_secret
  112. """
  113. UUID
  114. """
  class UUIDTransVar(HomoTransferBase):
      """Single server -> clients channel ``uuid`` carrying each client's assigned id."""

      def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
          super().__init__(server=server, clients=clients, prefix=prefix)
          uuid_channel = self.create_server_to_client_variable(name="uuid")
          self.uuid = uuid_channel
  class UUIDServer(object):
      """Issues a distinct, deterministic id to each client party."""

      def __init__(self, trans_var: UUIDTransVar = None):
          if trans_var is None:
              trans_var = UUIDTransVar()
          self._uuid_transfer = trans_var.uuid
          self._uuid_set = set()  # every id issued so far
          self._ind = -1  # last counter value used; the first id comes from 0
          self.client_parties = trans_var.client_parties

      # noinspection PyUnusedLocal
      @staticmethod
      def generate_id(ind, *args, **kwargs):
          """Return the hex MD5 digest of ``ind``'s text form; extra arguments are ignored."""
          payload = f"{ind}".encode("ascii")
          return hashlib.md5(payload).hexdigest()

      def _next_uuid(self):
          """Advance the counter until an unused id appears, record it and return it."""
          uid = None
          while uid is None or uid in self._uuid_set:
              self._ind += 1
              uid = self.generate_id(self._ind)
          self._uuid_set.add(uid)
          return uid

      def validate_uuid(self):
          """Send every client party its own freshly issued id, one party at a time."""
          for client in self.client_parties:
              self._uuid_transfer.remote_parties(obj=self._next_uuid(), parties=[client])
  class UUIDClient(object):
      """Receives the id that ``UUIDServer.validate_uuid`` sends to this party."""

      def __init__(self, trans_var: UUIDTransVar = None):
          trans_var = UUIDTransVar() if trans_var is None else trans_var
          self._uuid_variable = trans_var.uuid
          self._server_parties = trans_var.server_parties

      def generate_uuid(self):
          """Return the id received from the first server party."""
          received = self._uuid_variable.get_parties(parties=self._server_parties)
          return received[0]
  152. """
  153. Random Padding
  154. """
  class RandomPaddingCipherTransVar(HomoTransferBase):
      """Bundles the UUID and Diffie-Hellman variable sets used to set up a PadsCipher."""

      def __init__(self, server=(consts.ARBITER,), clients=(consts.GUEST, consts.HOST), prefix=None):
          super().__init__(server=server, clients=clients, prefix=prefix)
          # Both nested variable sets receive this instance's full prefix as their
          # own prefix, so their names are nested under it.
          nested_prefix = self.prefix
          self.uuid_transfer_variable = UUIDTransVar(server=server, clients=clients, prefix=nested_prefix)
          self.dh_transfer_variable = DHTransVar(server=server, clients=clients, prefix=nested_prefix)
  class RandomPaddingCipherServer(object):
      """Server side of random-padding cipher setup: assigns ids, then relays DH shares."""

      def __init__(self, trans_var: RandomPaddingCipherTransVar = None):
          trans_var = RandomPaddingCipherTransVar() if trans_var is None else trans_var
          self._uuid = UUIDServer(trans_var=trans_var.uuid_transfer_variable)
          self._dh = DHServer(trans_var=trans_var.dh_transfer_variable)

      def exchange_secret_keys(self):
          """Run the uuid phase, then the DH phase.

          This matches the order in which ``RandomPaddingCipherClient.create_cipher``
          consumes them.
          """
          phases = (
              ("synchronizing uuid", self._uuid.validate_uuid),
              ("Diffie-Hellman keys exchanging", self._dh.key_exchange),
          )
          for message, run_phase in phases:
              LOGGER.info(message)
              run_phase()
  class RandomPaddingCipherClient(object):
      """Client side of random-padding cipher setup, plus encryption with the result."""

      def __init__(self, trans_var: RandomPaddingCipherTransVar = None):
          if trans_var is None:
              trans_var = RandomPaddingCipherTransVar()
          self._uuid = UUIDClient(trans_var=trans_var.uuid_transfer_variable)
          self._dh = DHClient(trans_var=trans_var.dh_transfer_variable)
          self._cipher = None  # populated by create_cipher()

      def create_cipher(self) -> PadsCipher:
          """Obtain this party's uuid and DH-exchanged keys and build a PadsCipher.

          The phases run in the same order as
          ``RandomPaddingCipherServer.exchange_secret_keys`` (uuid, then DH).
          The cipher is also stored on the instance for later ``encrypt`` calls.

          :return: the configured ``PadsCipher``.
          """
          LOGGER.info("synchronizing uuid")
          uuid = self._uuid.generate_uuid()
          LOGGER.info("got local uuid")
          LOGGER.info("Diffie-Hellman keys exchanging")
          exchanged_keys = self._dh.key_exchange(uuid)
          LOGGER.info("got Diffie-Hellman exchanged keys")
          cipher = PadsCipher()
          cipher.set_self_uuid(uuid)
          cipher.set_exchanged_keys(exchanged_keys)
          self._cipher = cipher
          return cipher

      def encrypt(self, transfer_weights):
          """Encrypt ``transfer_weights`` with the cipher built by ``create_cipher``.

          :raises RuntimeError: if ``create_cipher`` has not been called yet.
          """
          if self._cipher is None:
              # Fail with an actionable message instead of
              # "'NoneType' object has no attribute 'encrypt'".
              raise RuntimeError("create_cipher() must be called before encrypt()")
          return self._cipher.encrypt(transfer_weights)