123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import operator
- from collections import Iterable
- import functools
- import numpy as np
- from fate_arch.common import Party
- from fate_arch.session import is_table
- from federatedml.secureprotol.spdz.beaver_triples import beaver_triplets
- from federatedml.secureprotol.spdz.tensor import fixedpoint_numpy
- from federatedml.secureprotol.spdz.tensor.base import TensorBase
- from federatedml.secureprotol.spdz.utils import NamingService
- from federatedml.secureprotol.spdz.utils import urand_tensor
- # from federatedml.secureprotol.spdz.tensor.fixedpoint_endec import FixedPointEndec
- from federatedml.secureprotol.fixedpoint import FixedPointEndec
- def _table_binary_op(x, y, op):
- return x.join(y, lambda a, b: op(a, b))
- def _table_binary_mod_op(x, y, q_field, op):
- return x.join(y, lambda a, b: op(a, b) % q_field)
- def _table_scalar_op(x, d, op):
- return x.mapValues(lambda a: op(a, d))
- def _table_scalar_mod_op(x, d, q_field, op):
- return x.mapValues(lambda a: op(a, d) % q_field)
- def _table_dot_mod_func(it, q_field):
- ret = None
- for _, (x, y) in it:
- if ret is None:
- ret = np.tensordot(x, y, [[], []]) % q_field
- else:
- ret = (ret + np.tensordot(x, y, [[], []])) % q_field
- return ret
- def _table_dot_func(it):
- ret = None
- for _, (x, y) in it:
- if ret is None:
- ret = np.tensordot(x, y, [[], []])
- else:
- ret += np.tensordot(x, y, [[], []])
- return ret
- def table_dot(a_table, b_table):
- return a_table.join(b_table, lambda x, y: [x, y]) \
- .applyPartitions(lambda it: _table_dot_func(it)) \
- .reduce(lambda x, y: x + y)
- def table_dot_mod(a_table, b_table, q_field):
- return a_table.join(b_table, lambda x, y: [x, y]) \
- .applyPartitions(lambda it: _table_dot_mod_func(it, q_field)) \
- .reduce(lambda x, y: x if y is None else y if x is None else x + y)
- class FixedPointTensor(TensorBase):
- """
- a table based tensor
- """
- __array_ufunc__ = None
- def __init__(self, value, q_field, endec, tensor_name: str = None):
- super().__init__(q_field, tensor_name)
- self.value = value
- self.endec = endec
- self.tensor_name = NamingService.get_instance().next() if tensor_name is None else tensor_name
- def dot(self, other: 'FixedPointTensor', target_name=None):
- spdz = self.get_spdz()
- if target_name is None:
- target_name = NamingService.get_instance().next()
- a, b, c = beaver_triplets(a_tensor=self.value, b_tensor=other.value, dot=table_dot,
- q_field=self.q_field, he_key_pair=(spdz.public_key, spdz.private_key),
- communicator=spdz.communicator, name=target_name)
- x_add_a = (self + a).rescontruct(f"{target_name}_confuse_x")
- y_add_b = (other + b).rescontruct(f"{target_name}_confuse_y")
- cross = c - table_dot_mod(a, y_add_b, self.q_field) - table_dot_mod(x_add_a, b, self.q_field)
- if spdz.party_idx == 0:
- cross += table_dot_mod(x_add_a, y_add_b, self.q_field)
- cross = cross % self.q_field
- cross = self.endec.truncate(cross, self.get_spdz().party_idx)
- share = fixedpoint_numpy.FixedPointTensor(cross, self.q_field, self.endec, target_name)
- return share
- def dot_local(self, other, target_name=None):
- def _vec_dot(x, y, party_idx, q_field, endec):
- ret = np.dot(x, y) % q_field
- ret = endec.truncate(ret, party_idx)
- if not isinstance(ret, np.ndarray):
- ret = np.array([ret])
- return ret
- if isinstance(other, FixedPointTensor) or isinstance(other, fixedpoint_numpy.FixedPointTensor):
- other = other.value
- if isinstance(other, np.ndarray):
- party_idx = self.get_spdz().party_idx
- f = functools.partial(_vec_dot, y=other,
- party_idx=party_idx,
- q_field=self.q_field,
- endec=self.endec)
- ret = self.value.mapValues(f)
- return self._boxed(ret, target_name)
- elif is_table(other):
- ret = table_dot_mod(self.value, other, self.q_field).reshape((1, -1))[0]
- ret = self.endec.truncate(ret, self.get_spdz().party_idx)
- return fixedpoint_numpy.FixedPointTensor(ret,
- self.q_field,
- self.endec,
- target_name)
- else:
- raise ValueError(f"type={type(other)}")
- def reduce(self, func, **kwargs):
- ret = self.value.reduce(func)
- return fixedpoint_numpy.FixedPointTensor(ret,
- self.q_field,
- self.endec
- )
- @property
- def shape(self):
- return self.value.count(), len(self.value.first()[1])
- @classmethod
- def from_source(cls, tensor_name, source, **kwargs):
- spdz = cls.get_spdz()
- q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
- if 'encoder' in kwargs:
- encoder = kwargs['encoder']
- else:
- base = kwargs['base'] if 'base' in kwargs else 10
- frac = kwargs['frac'] if 'frac' in kwargs else 4
- encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
- if is_table(source):
- source = encoder.encode(source)
- _pre = urand_tensor(q_field, source, use_mix=spdz.use_mix_rand)
- spdz.communicator.remote_share(share=_pre, tensor_name=tensor_name, party=spdz.other_parties[0])
- for _party in spdz.other_parties[1:]:
- r = urand_tensor(q_field, source, use_mix=spdz.use_mix_rand)
- spdz.communicator.remote_share(share=_table_binary_mod_op(r, _pre, q_field, operator.sub),
- tensor_name=tensor_name, party=_party)
- _pre = r
- share = _table_binary_mod_op(source, _pre, q_field, operator.sub)
- elif isinstance(source, Party):
- share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
- else:
- raise ValueError(f"type={type(source)}")
- return FixedPointTensor(share, q_field, encoder, tensor_name)
- def get(self, tensor_name=None, broadcast=True):
- return self.endec.decode(self.rescontruct(tensor_name, broadcast))
- def rescontruct(self, tensor_name=None, broadcast=True):
- from federatedml.secureprotol.spdz import SPDZ
- spdz = SPDZ.get_instance()
- share_val = self.value.copy()
- name = tensor_name or self.tensor_name
- if name is None:
- raise ValueError("name not specified")
- # remote share to other parties
- if broadcast:
- spdz.communicator.broadcast_rescontruct_share(share_val, name)
- # get shares from other parties
- for other_share in spdz.communicator.get_rescontruct_shares(name):
- share_val = _table_binary_mod_op(share_val, other_share, self.q_field, operator.add)
- return share_val
- def broadcast_reconstruct_share(self, tensor_name=None):
- from federatedml.secureprotol.spdz import SPDZ
- spdz = SPDZ.get_instance()
- share_val = self.value.copy()
- name = tensor_name or self.tensor_name
- if name is None:
- raise ValueError("name not specified")
- # remote share to other parties
- spdz.communicator.broadcast_rescontruct_share(share_val, name)
- return share_val
- def __str__(self):
- return f"tensor_name={self.tensor_name}, value={self.value}"
- def __repr__(self):
- return self.__str__()
- def as_name(self, tensor_name):
- return self._boxed(value=self.value, tensor_name=tensor_name)
- def __add__(self, other):
- if isinstance(other, PaillierFixedPointTensor):
- z_value = _table_binary_op(self.value, other.value, operator.add)
- return PaillierFixedPointTensor(z_value)
- elif isinstance(other, FixedPointTensor):
- z_value = _table_binary_mod_op(self.value, other.value, self.q_field, operator.add)
- elif is_table(other):
- z_value = _table_binary_mod_op(self.value, other, self.q_field, operator.add)
- else:
- z_value = _table_scalar_mod_op(self.value, other, self.q_field, operator.add)
- return self._boxed(z_value)
- def __radd__(self, other):
- return self.__add__(other)
- def __sub__(self, other):
- if isinstance(other, PaillierFixedPointTensor):
- z_value = _table_binary_op(self.value, other.value, operator.sub)
- return PaillierFixedPointTensor(z_value)
- elif isinstance(other, FixedPointTensor):
- z_value = _table_binary_mod_op(self.value, other.value, self.q_field, operator.sub)
- elif is_table(other):
- z_value = _table_binary_mod_op(self.value, other, self.q_field, operator.sub)
- else:
- z_value = _table_scalar_mod_op(self.value, other, self.q_field, operator.sub)
- return self._boxed(z_value)
- def __rsub__(self, other):
- if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
- return other - self
- elif is_table(other):
- z_value = _table_binary_mod_op(other, self.value, self.q_field, operator.sub)
- else:
- z_value = _table_scalar_mod_op(self.value, other, self.q_field, -1 * operator.sub)
- return self._boxed(z_value)
- def __mul__(self, other):
- if isinstance(other, FixedPointTensor):
- z_value = _table_binary_mod_op(self.value, other.value, self.q_field, operator.mul)
- elif isinstance(other, PaillierFixedPointTensor):
- z_value = _table_binary_op(self.value, other.value, operator.mul)
- else:
- z_value = _table_scalar_mod_op(self.value, other, self.q_field, operator.mul)
- z_value = self.endec.truncate(z_value, self.get_spdz().party_idx)
- return self._boxed(z_value)
- def __rmul__(self, other):
- return self.__mul__(other)
- def __mod__(self, other):
- if not isinstance(other, (int, np.integer)):
- raise NotImplementedError("__mod__ support integer only")
- return self._boxed(_table_scalar_op(self.value, other, operator.mod))
- def _boxed(self, value, tensor_name=None):
- return FixedPointTensor(value=value, q_field=self.q_field, endec=self.endec, tensor_name=tensor_name)
- class PaillierFixedPointTensor(TensorBase):
- __array_ufunc__ = None
- def __init__(self, value, tensor_name: str = None, cipher=None):
- super().__init__(q_field=None, tensor_name=tensor_name)
- self.value = value
- self.cipher = cipher
- def dot(self, other, target_name=None):
- def _vec_dot(x, y):
- ret = np.dot(x, y)
- if not isinstance(ret, np.ndarray):
- ret = np.array([ret])
- return ret
- if isinstance(other, (FixedPointTensor, fixedpoint_numpy.FixedPointTensor)):
- other = other.value
- if isinstance(other, np.ndarray):
- ret = self.value.mapValues(lambda x: _vec_dot(x, other))
- return self._boxed(ret, target_name)
- elif is_table(other):
- ret = table_dot(self.value, other).reshape((1, -1))[0]
- return fixedpoint_numpy.PaillierFixedPointTensor(ret, target_name)
- else:
- raise ValueError(f"type={type(other)}")
- def reduce(self, func, **kwargs):
- ret = self.value.reduce(func)
- return fixedpoint_numpy.PaillierFixedPointTensor(ret)
- def __str__(self):
- return f"tensor_name={self.tensor_name}, value={self.value}"
- def __repr__(self):
- return self.__str__()
- def __add__(self, other):
- if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
- return self._boxed(_table_binary_op(self.value, other.value, operator.add))
- elif is_table(other):
- return self._boxed(_table_binary_op(self.value, other, operator.add))
- else:
- return self._boxed(_table_scalar_op(self.value, other, operator.add))
- def __radd__(self, other):
- return self.__add__(other)
- def __sub__(self, other):
- if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
- return self._boxed(_table_binary_op(self.value, other.value, operator.sub))
- elif is_table(other):
- return self._boxed(_table_binary_op(self.value, other, operator.sub))
- else:
- return self._boxed(_table_scalar_op(self.value, other, operator.sub))
- def __rsub__(self, other):
- if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
- return self._boxed(_table_binary_op(other.value, self.value, operator.sub))
- elif is_table(other):
- return self._boxed(_table_binary_op(other, self.value, operator.sub))
- else:
- return self._boxed(_table_scalar_op(self.value, other, -1 * operator.sub))
- def __mul__(self, other):
- if isinstance(other, FixedPointTensor):
- z_value = _table_binary_op(self.value, other.value, operator.mul)
- elif is_table(other):
- z_value = _table_binary_op(self.value, other, operator.mul)
- else:
- z_value = _table_scalar_op(self.value, other, operator.mul)
- return self._boxed(z_value)
- def __rmul__(self, other):
- return self.__mul__(other)
- def _boxed(self, value, tensor_name=None):
- return PaillierFixedPointTensor(value=value, tensor_name=tensor_name)
- @classmethod
- def from_source(cls, tensor_name, source, **kwargs):
- spdz = cls.get_spdz()
- q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
- if 'encoder' in kwargs:
- encoder = kwargs['encoder']
- else:
- base = kwargs['base'] if 'base' in kwargs else 10
- frac = kwargs['frac'] if 'frac' in kwargs else 4
- encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
- if is_table(source):
- _pre = urand_tensor(q_field, source, use_mix=spdz.use_mix_rand)
- share = _pre
- spdz.communicator.remote_share(share=_table_binary_op(source, encoder.decode(_pre), operator.sub),
- tensor_name=tensor_name, party=spdz.other_parties[-1])
- return FixedPointTensor(value=share,
- q_field=q_field,
- endec=encoder,
- tensor_name=tensor_name)
- elif isinstance(source, Party):
- share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
- is_cipher_source = kwargs['is_cipher_source'] if 'is_cipher_source' in kwargs else True
- if is_cipher_source:
- cipher = kwargs.get("cipher")
- if cipher is None:
- raise ValueError("Cipher is not provided")
- share = cipher.distribute_decrypt(share)
- share = encoder.encode(share)
- return FixedPointTensor(value=share,
- q_field=q_field,
- endec=encoder,
- tensor_name=tensor_name)
- else:
- raise ValueError(f"type={type(source)}")
|