123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import numpy as np
- from fate_arch.session import is_table
- from federatedml.secureprotol.spdz.communicator import Communicator
- from federatedml.secureprotol.spdz.utils import rand_tensor, urand_tensor
- from federatedml.util import LOGGER
- def encrypt_tensor(tensor, public_key):
- encrypted_zero = public_key.encrypt(0)
- if isinstance(tensor, np.ndarray):
- return np.vectorize(lambda e: encrypted_zero + e)(tensor)
- elif is_table(tensor):
- return tensor.mapValues(lambda x: np.vectorize(lambda e: encrypted_zero + e)(x))
- else:
- raise NotImplementedError(f"type={type(tensor)}")
- def decrypt_tensor(tensor, private_key, otypes):
- if isinstance(tensor, np.ndarray):
- return np.vectorize(private_key.decrypt, otypes)(tensor)
- elif is_table(tensor):
- return tensor.mapValues(lambda x: np.vectorize(private_key.decrypt, otypes)(x))
- else:
- raise NotImplementedError(f"type={type(tensor)}")
- def beaver_triplets(a_tensor, b_tensor, dot, q_field, he_key_pair, communicator: Communicator, name):
- public_key, private_key = he_key_pair
- a = rand_tensor(q_field, a_tensor)
- b = rand_tensor(q_field, b_tensor)
- def _cross(self_index, other_index):
- LOGGER.debug(f"_cross: a={a}, b={b}")
- _c = dot(a, b)
- encrypted_a = encrypt_tensor(a, public_key)
- communicator.remote_encrypted_tensor(encrypted=encrypted_a, tag=f"{name}_a_{self_index}")
- r = urand_tensor(q_field, _c)
- _p, (ea,) = communicator.get_encrypted_tensors(tag=f"{name}_a_{other_index}")
- eab = dot(ea, b)
- eab += r
- _c -= r
- communicator.remote_encrypted_cross_tensor(encrypted=eab,
- parties=_p,
- tag=f"{name}_cross_a_{other_index}_b_{self_index}")
- crosses = communicator.get_encrypted_cross_tensors(tag=f"{name}_cross_a_{self_index}_b_{other_index}")
- for eab in crosses:
- _c += decrypt_tensor(eab, private_key, [object])
- return _c
- c = _cross(communicator.party_idx, 1 - communicator.party_idx)
- return a, b, c % q_field
|