123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import functools
- import numpy as np
- from fate_arch.common import Party
- from fate_arch.computing import is_table
- from federatedml.secureprotol.spdz.beaver_triples import beaver_triplets
- from federatedml.secureprotol.spdz.tensor import fixedpoint_table
- from federatedml.secureprotol.spdz.tensor.base import TensorBase
- from federatedml.secureprotol.spdz.utils import urand_tensor
- # from federatedml.secureprotol.spdz.tensor.fixedpoint_endec import FixedPointEndec
- from federatedml.secureprotol.fixedpoint import FixedPointEndec
- from federatedml.util import LOGGER
- class FixedPointTensor(TensorBase):
- __array_ufunc__ = None
- def __init__(self, value, q_field, endec, tensor_name: str = None):
- super().__init__(q_field, tensor_name)
- self.endec = endec
- self.value = value
- @property
- def shape(self):
- return self.value.shape
- def reshape(self, shape):
- return self._boxed(self.value.reshape(shape))
- def dot(self, other, target_name=None):
- return self.einsum(other, "ij,ik->jk", target_name)
- def dot_local(self, other, target_name=None):
- if isinstance(other, FixedPointTensor):
- other = other.value
- ret = np.dot(self.value, other) % self.q_field
- ret = self.endec.truncate(ret, self.get_spdz().party_idx)
- if not isinstance(ret, np.ndarray):
- ret = np.array([ret])
- return self._boxed(ret, target_name)
- def sub_matrix(self, tensor_name: str, row_indices=None, col_indices=None, rm_row_indices=None,
- rm_col_indices=None):
- if row_indices is not None:
- x_indices = list(row_indices)
- elif row_indices is None and rm_row_indices is not None:
- x_indices = [i for i in range(self.value.shape[0]) if i not in rm_row_indices]
- else:
- raise RuntimeError(f"invalid argument")
- if col_indices is not None:
- y_indices = list(col_indices)
- elif row_indices is None and rm_col_indices is not None:
- y_indices = [i for i in range(self.value.shape[0]) if i not in rm_col_indices]
- else:
- raise RuntimeError(f"invalid argument")
- value = self.value[x_indices, :][:, y_indices]
- return FixedPointTensor(value=value, q_field=self.q_field, endec=self.endec, tensor_name=tensor_name)
- @classmethod
- def from_source(cls, tensor_name, source, **kwargs):
- spdz = cls.get_spdz()
- q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
- if 'encoder' in kwargs:
- encoder = kwargs['encoder']
- else:
- base = kwargs['base'] if 'base' in kwargs else 10
- frac = kwargs['frac'] if 'frac' in kwargs else 4
- encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
- if isinstance(source, np.ndarray):
- source = encoder.encode(source)
- _pre = urand_tensor(q_field, source)
- spdz.communicator.remote_share(share=_pre, tensor_name=tensor_name, party=spdz.other_parties[0])
- for _party in spdz.other_parties[1:]:
- r = urand_tensor(q_field, source)
- spdz.communicator.remote_share(share=(r - _pre) % q_field, tensor_name=tensor_name, party=_party)
- _pre = r
- share = (source - _pre) % q_field
- elif isinstance(source, Party):
- share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
- else:
- raise ValueError(f"type={type(source)}")
- return FixedPointTensor(share, q_field, encoder, tensor_name)
- def einsum(self, other: 'FixedPointTensor', einsum_expr, target_name=None):
- spdz = self.get_spdz()
- target_name = target_name or spdz.name_service.next()
- def _dot_func(_x, _y):
- ret = np.dot(_x, _y)
- if not isinstance(ret, np.ndarray):
- ret = np.array([ret])
- return ret
- # return np.einsum(einsum_expr, _x, _y, optimize=True)
- a, b, c = beaver_triplets(a_tensor=self.value, b_tensor=other.value, dot=_dot_func,
- q_field=self.q_field, he_key_pair=(spdz.public_key, spdz.private_key),
- communicator=spdz.communicator, name=target_name)
- x_add_a = self._raw_add(a).reconstruct(f"{target_name}_confuse_x")
- y_add_b = other._raw_add(b).reconstruct(f"{target_name}_confuse_y")
- cross = c - _dot_func(a, y_add_b) - _dot_func(x_add_a, b)
- if spdz.party_idx == 0:
- cross += _dot_func(x_add_a, y_add_b)
- cross = cross % self.q_field
- cross = self.endec.truncate(cross, self.get_spdz().party_idx)
- share = self._boxed(cross, tensor_name=target_name)
- return share
- def get(self, tensor_name=None, broadcast=True):
- return self.endec.decode(self.reconstruct(tensor_name, broadcast))
- def reconstruct(self, tensor_name=None, broadcast=True):
- from federatedml.secureprotol.spdz import SPDZ
- spdz = SPDZ.get_instance()
- share_val = self.value.copy()
- name = tensor_name or self.tensor_name
- if name is None:
- raise ValueError("name not specified")
- # remote share to other parties
- if broadcast:
- spdz.communicator.broadcast_rescontruct_share(share_val, name)
- # get shares from other parties
- for other_share in spdz.communicator.get_rescontruct_shares(name):
- # LOGGER.debug(f"share_val: {share_val}, other_share: {other_share}")
- share_val += other_share
- try:
- share_val %= self.q_field
- return share_val
- except BaseException:
- return share_val
- def transpose(self):
- value = self.value.transpose()
- return self._boxed(value)
- def broadcast_reconstruct_share(self, tensor_name=None):
- from federatedml.secureprotol.spdz import SPDZ
- spdz = SPDZ.get_instance()
- share_val = self.value.copy()
- name = tensor_name or self.tensor_name
- if name is None:
- raise ValueError("name not specified")
- # remote share to other parties
- spdz.communicator.broadcast_rescontruct_share(share_val, name)
- return share_val
- def _boxed(self, value, tensor_name=None):
- return FixedPointTensor(value=value, q_field=self.q_field, endec=self.endec, tensor_name=tensor_name)
- def __str__(self):
- return f"tensor_name={self.tensor_name}, value={self.value}"
- def __repr__(self):
- return self.__str__()
- def as_name(self, tensor_name):
- return self._boxed(value=self.value, tensor_name=tensor_name)
- def _raw_add(self, other):
- z_value = (self.value + other) % self.q_field
- return self._boxed(z_value)
- def _raw_sub(self, other):
- z_value = (self.value - other) % self.q_field
- return self._boxed(z_value)
- def __add__(self, other):
- if isinstance(other, PaillierFixedPointTensor):
- z_value = (self.value + other)
- return PaillierFixedPointTensor(z_value)
- elif isinstance(other, FixedPointTensor):
- return self._raw_add(other.value)
- z_value = (self.value + other) % self.q_field
- return self._boxed(z_value)
- def __radd__(self, other):
- return self.__add__(other)
- def __sub__(self, other):
- if isinstance(other, PaillierFixedPointTensor):
- z_value = (self.value - other)
- return PaillierFixedPointTensor(z_value)
- elif isinstance(other, FixedPointTensor):
- return self._raw_sub(other.value)
- z_value = (self.value - other) % self.q_field
- return self._boxed(z_value)
- def __rsub__(self, other):
- if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
- return other - self
- z_value = (other - self.value) % self.q_field
- return self._boxed(z_value)
- def __mul__(self, other):
- if isinstance(other, PaillierFixedPointTensor):
- z_value = self.value * other.value
- return PaillierFixedPointTensor(z_value)
- if isinstance(other, FixedPointTensor):
- other = other.value
- z_value = self.value * other
- z_value = z_value % self.q_field
- z_value = self.endec.truncate(z_value, self.get_spdz().party_idx)
- return self._boxed(z_value)
- def __rmul__(self, other):
- return self.__mul__(other)
- def __matmul__(self, other):
- return self.einsum(other, "ij,jk->ik")
- class PaillierFixedPointTensor(TensorBase):
- __array_ufunc__ = None
- def __init__(self, value, tensor_name: str = None, cipher=None):
- super().__init__(q_field=None, tensor_name=tensor_name)
- self.value = value
- self.cipher = cipher
- def dot(self, other, target_name=None):
- def _vec_dot(x, y):
- ret = np.dot(x, y)
- if not isinstance(ret, np.ndarray):
- ret = np.array([ret])
- return ret
- if isinstance(other, (FixedPointTensor, fixedpoint_table.FixedPointTensor)):
- other = other.value
- if isinstance(other, np.ndarray):
- ret = _vec_dot(self.value, other)
- return self._boxed(ret, target_name)
- elif is_table(other):
- f = functools.partial(_vec_dot,
- self.value)
- ret = other.mapValues(f)
- return fixedpoint_table.PaillierFixedPointTensor(value=ret,
- tensor_name=target_name,
- cipher=self.cipher)
- else:
- raise ValueError(f"type={type(other)}")
- def broadcast_reconstruct_share(self, tensor_name=None):
- from federatedml.secureprotol.spdz import SPDZ
- spdz = SPDZ.get_instance()
- share_val = self.value.copy()
- name = tensor_name or self.tensor_name
- if name is None:
- raise ValueError("name not specified")
- # remote share to other parties
- spdz.communicator.broadcast_rescontruct_share(share_val, name)
- return share_val
- def __str__(self):
- return f"tensor_name={self.tensor_name}, value={self.value}"
- def __repr__(self):
- return self.__str__()
- def _raw_add(self, other):
- z_value = (self.value + other)
- return self._boxed(z_value)
- def _raw_sub(self, other):
- z_value = (self.value - other)
- return self._boxed(z_value)
- def __add__(self, other):
- if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
- return self._raw_add(other.value)
- else:
- return self._raw_add(other)
- def __radd__(self, other):
- return self.__add__(other)
- def __sub__(self, other):
- if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
- return self._raw_sub(other.value)
- else:
- return self._raw_sub(other)
- def __rsub__(self, other):
- if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
- z_value = other.value - self.value
- else:
- z_value = other - self.value
- return self._boxed(z_value)
- def __mul__(self, other):
- if isinstance(other, PaillierFixedPointTensor):
- raise NotImplementedError("__mul__ not support PaillierFixedPointTensor")
- elif isinstance(other, FixedPointTensor):
- return self._boxed(self.value * other.value)
- else:
- return self._boxed(self.value * other)
- def __rmul__(self, other):
- self.__mul__(other)
- def _boxed(self, value, tensor_name=None):
- return PaillierFixedPointTensor(value=value,
- tensor_name=tensor_name,
- cipher=self.cipher)
- @classmethod
- def from_source(cls, tensor_name, source, **kwargs):
- spdz = cls.get_spdz()
- q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
- if 'encoder' in kwargs:
- encoder = kwargs['encoder']
- else:
- base = kwargs['base'] if 'base' in kwargs else 10
- frac = kwargs['frac'] if 'frac' in kwargs else 4
- encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
- if isinstance(source, np.ndarray):
- _pre = urand_tensor(q_field, source)
- share = _pre
- spdz.communicator.remote_share(share=source - encoder.decode(_pre),
- tensor_name=tensor_name,
- party=spdz.other_parties[-1])
- return FixedPointTensor(value=share,
- q_field=q_field,
- endec=encoder,
- tensor_name=tensor_name)
- elif isinstance(source, Party):
- share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
- is_cipher_source = kwargs['is_cipher_source'] if 'is_cipher_source' in kwargs else True
- if is_cipher_source:
- cipher = kwargs['cipher']
- share = cipher.recursive_decrypt(share)
- share = encoder.encode(share)
- return FixedPointTensor(value=share,
- q_field=q_field,
- endec=encoder,
- tensor_name=tensor_name)
- else:
- raise ValueError(f"type={type(source)}")
|