he.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import numpy as np
  17. from fate_arch.session import is_table
  18. from federatedml.secureprotol.spdz.communicator import Communicator
  19. from federatedml.secureprotol.spdz.utils import rand_tensor, urand_tensor
  20. from federatedml.util import LOGGER
  21. def encrypt_tensor(tensor, public_key):
  22. encrypted_zero = public_key.encrypt(0)
  23. if isinstance(tensor, np.ndarray):
  24. return np.vectorize(lambda e: encrypted_zero + e)(tensor)
  25. elif is_table(tensor):
  26. return tensor.mapValues(lambda x: np.vectorize(lambda e: encrypted_zero + e)(x))
  27. else:
  28. raise NotImplementedError(f"type={type(tensor)}")
  29. def decrypt_tensor(tensor, private_key, otypes):
  30. if isinstance(tensor, np.ndarray):
  31. return np.vectorize(private_key.decrypt, otypes)(tensor)
  32. elif is_table(tensor):
  33. return tensor.mapValues(lambda x: np.vectorize(private_key.decrypt, otypes)(x))
  34. else:
  35. raise NotImplementedError(f"type={type(tensor)}")
  36. def beaver_triplets(a_tensor, b_tensor, dot, q_field, he_key_pair, communicator: Communicator, name):
  37. public_key, private_key = he_key_pair
  38. a = rand_tensor(q_field, a_tensor)
  39. b = rand_tensor(q_field, b_tensor)
  40. def _cross(self_index, other_index):
  41. LOGGER.debug(f"_cross: a={a}, b={b}")
  42. _c = dot(a, b)
  43. encrypted_a = encrypt_tensor(a, public_key)
  44. communicator.remote_encrypted_tensor(encrypted=encrypted_a, tag=f"{name}_a_{self_index}")
  45. r = urand_tensor(q_field, _c)
  46. _p, (ea,) = communicator.get_encrypted_tensors(tag=f"{name}_a_{other_index}")
  47. eab = dot(ea, b)
  48. eab += r
  49. _c -= r
  50. communicator.remote_encrypted_cross_tensor(encrypted=eab,
  51. parties=_p,
  52. tag=f"{name}_cross_a_{other_index}_b_{self_index}")
  53. crosses = communicator.get_encrypted_cross_tensors(tag=f"{name}_cross_a_{self_index}_b_{other_index}")
  54. for eab in crosses:
  55. _c += decrypt_tensor(eab, private_key, [object])
  56. return _c
  57. c = _cross(communicator.party_idx, 1 - communicator.party_idx)
  58. return a, b, c % q_field