spdz.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.secureprotol.fate_paillier import PaillierKeypair
  17. from federatedml.secureprotol.spdz.communicator import Communicator
  18. from federatedml.secureprotol.spdz.utils import NamingService
  19. from federatedml.secureprotol.spdz.utils import naming
  20. class SPDZ(object):
  21. __instance = None
  22. @classmethod
  23. def get_instance(cls) -> 'SPDZ':
  24. return cls.__instance
  25. @classmethod
  26. def set_instance(cls, instance):
  27. prev = cls.__instance
  28. cls.__instance = instance
  29. return prev
  30. @classmethod
  31. def has_instance(cls):
  32. return cls.__instance is not None
  33. def __init__(self, name="ss", q_field=None, local_party=None, all_parties=None, use_mix_rand=False, n_length=1024):
  34. self.name_service = naming.NamingService(name)
  35. self._prev_name_service = None
  36. self._pre_instance = None
  37. self.communicator = Communicator(local_party, all_parties)
  38. self.party_idx = self.communicator.party_idx
  39. self.other_parties = self.communicator.other_parties
  40. if len(self.other_parties) > 1:
  41. raise EnvironmentError("support 2-party secret share only")
  42. self.public_key, self.private_key = PaillierKeypair.generate_keypair(n_length=n_length)
  43. if q_field is None:
  44. q_field = self.public_key.n
  45. self.q_field = self._align_q_field(q_field)
  46. self.use_mix_rand = use_mix_rand
  47. def __enter__(self):
  48. self._prev_name_service = NamingService.set_instance(self.name_service)
  49. self._pre_instance = self.set_instance(self)
  50. return self
  51. def __exit__(self, exc_type, exc_val, exc_tb):
  52. NamingService.set_instance(self._pre_instance)
  53. # self.communicator.clean()
  54. def __reduce__(self):
  55. raise PermissionError("it's unsafe to transfer this")
  56. def partial_rescontruct(self):
  57. # todo: partial parties gets rescontructed tensor
  58. pass
  59. @classmethod
  60. def dot(cls, left, right, target_name=None):
  61. return left.dot(right, target_name)
  62. def set_flowid(self, flowid):
  63. self.communicator.set_flowid(flowid)
  64. def _align_q_field(self, q_field):
  65. self.communicator.remote_q_field(q_field=q_field, party=self.other_parties)
  66. other_q_field = self.communicator.get_q_field(party=self.other_parties)
  67. other_q_field.append(q_field)
  68. max_q_field = max(other_q_field)
  69. return max_q_field