fixedpoint_numpy.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import functools
  17. import numpy as np
  18. from fate_arch.common import Party
  19. from fate_arch.computing import is_table
  20. from federatedml.secureprotol.spdz.beaver_triples import beaver_triplets
  21. from federatedml.secureprotol.spdz.tensor import fixedpoint_table
  22. from federatedml.secureprotol.spdz.tensor.base import TensorBase
  23. from federatedml.secureprotol.spdz.utils import urand_tensor
  24. # from federatedml.secureprotol.spdz.tensor.fixedpoint_endec import FixedPointEndec
  25. from federatedml.secureprotol.fixedpoint import FixedPointEndec
  26. from federatedml.util import LOGGER
  27. class FixedPointTensor(TensorBase):
  28. __array_ufunc__ = None
  29. def __init__(self, value, q_field, endec, tensor_name: str = None):
  30. super().__init__(q_field, tensor_name)
  31. self.endec = endec
  32. self.value = value
  33. @property
  34. def shape(self):
  35. return self.value.shape
  36. def reshape(self, shape):
  37. return self._boxed(self.value.reshape(shape))
  38. def dot(self, other, target_name=None):
  39. return self.einsum(other, "ij,ik->jk", target_name)
  40. def dot_local(self, other, target_name=None):
  41. if isinstance(other, FixedPointTensor):
  42. other = other.value
  43. ret = np.dot(self.value, other) % self.q_field
  44. ret = self.endec.truncate(ret, self.get_spdz().party_idx)
  45. if not isinstance(ret, np.ndarray):
  46. ret = np.array([ret])
  47. return self._boxed(ret, target_name)
  48. def sub_matrix(self, tensor_name: str, row_indices=None, col_indices=None, rm_row_indices=None,
  49. rm_col_indices=None):
  50. if row_indices is not None:
  51. x_indices = list(row_indices)
  52. elif row_indices is None and rm_row_indices is not None:
  53. x_indices = [i for i in range(self.value.shape[0]) if i not in rm_row_indices]
  54. else:
  55. raise RuntimeError(f"invalid argument")
  56. if col_indices is not None:
  57. y_indices = list(col_indices)
  58. elif row_indices is None and rm_col_indices is not None:
  59. y_indices = [i for i in range(self.value.shape[0]) if i not in rm_col_indices]
  60. else:
  61. raise RuntimeError(f"invalid argument")
  62. value = self.value[x_indices, :][:, y_indices]
  63. return FixedPointTensor(value=value, q_field=self.q_field, endec=self.endec, tensor_name=tensor_name)
  64. @classmethod
  65. def from_source(cls, tensor_name, source, **kwargs):
  66. spdz = cls.get_spdz()
  67. q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
  68. if 'encoder' in kwargs:
  69. encoder = kwargs['encoder']
  70. else:
  71. base = kwargs['base'] if 'base' in kwargs else 10
  72. frac = kwargs['frac'] if 'frac' in kwargs else 4
  73. encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
  74. if isinstance(source, np.ndarray):
  75. source = encoder.encode(source)
  76. _pre = urand_tensor(q_field, source)
  77. spdz.communicator.remote_share(share=_pre, tensor_name=tensor_name, party=spdz.other_parties[0])
  78. for _party in spdz.other_parties[1:]:
  79. r = urand_tensor(q_field, source)
  80. spdz.communicator.remote_share(share=(r - _pre) % q_field, tensor_name=tensor_name, party=_party)
  81. _pre = r
  82. share = (source - _pre) % q_field
  83. elif isinstance(source, Party):
  84. share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
  85. else:
  86. raise ValueError(f"type={type(source)}")
  87. return FixedPointTensor(share, q_field, encoder, tensor_name)
  88. def einsum(self, other: 'FixedPointTensor', einsum_expr, target_name=None):
  89. spdz = self.get_spdz()
  90. target_name = target_name or spdz.name_service.next()
  91. def _dot_func(_x, _y):
  92. ret = np.dot(_x, _y)
  93. if not isinstance(ret, np.ndarray):
  94. ret = np.array([ret])
  95. return ret
  96. # return np.einsum(einsum_expr, _x, _y, optimize=True)
  97. a, b, c = beaver_triplets(a_tensor=self.value, b_tensor=other.value, dot=_dot_func,
  98. q_field=self.q_field, he_key_pair=(spdz.public_key, spdz.private_key),
  99. communicator=spdz.communicator, name=target_name)
  100. x_add_a = self._raw_add(a).reconstruct(f"{target_name}_confuse_x")
  101. y_add_b = other._raw_add(b).reconstruct(f"{target_name}_confuse_y")
  102. cross = c - _dot_func(a, y_add_b) - _dot_func(x_add_a, b)
  103. if spdz.party_idx == 0:
  104. cross += _dot_func(x_add_a, y_add_b)
  105. cross = cross % self.q_field
  106. cross = self.endec.truncate(cross, self.get_spdz().party_idx)
  107. share = self._boxed(cross, tensor_name=target_name)
  108. return share
  109. def get(self, tensor_name=None, broadcast=True):
  110. return self.endec.decode(self.reconstruct(tensor_name, broadcast))
  111. def reconstruct(self, tensor_name=None, broadcast=True):
  112. from federatedml.secureprotol.spdz import SPDZ
  113. spdz = SPDZ.get_instance()
  114. share_val = self.value.copy()
  115. name = tensor_name or self.tensor_name
  116. if name is None:
  117. raise ValueError("name not specified")
  118. # remote share to other parties
  119. if broadcast:
  120. spdz.communicator.broadcast_rescontruct_share(share_val, name)
  121. # get shares from other parties
  122. for other_share in spdz.communicator.get_rescontruct_shares(name):
  123. # LOGGER.debug(f"share_val: {share_val}, other_share: {other_share}")
  124. share_val += other_share
  125. try:
  126. share_val %= self.q_field
  127. return share_val
  128. except BaseException:
  129. return share_val
  130. def transpose(self):
  131. value = self.value.transpose()
  132. return self._boxed(value)
  133. def broadcast_reconstruct_share(self, tensor_name=None):
  134. from federatedml.secureprotol.spdz import SPDZ
  135. spdz = SPDZ.get_instance()
  136. share_val = self.value.copy()
  137. name = tensor_name or self.tensor_name
  138. if name is None:
  139. raise ValueError("name not specified")
  140. # remote share to other parties
  141. spdz.communicator.broadcast_rescontruct_share(share_val, name)
  142. return share_val
  143. def _boxed(self, value, tensor_name=None):
  144. return FixedPointTensor(value=value, q_field=self.q_field, endec=self.endec, tensor_name=tensor_name)
  145. def __str__(self):
  146. return f"tensor_name={self.tensor_name}, value={self.value}"
  147. def __repr__(self):
  148. return self.__str__()
  149. def as_name(self, tensor_name):
  150. return self._boxed(value=self.value, tensor_name=tensor_name)
  151. def _raw_add(self, other):
  152. z_value = (self.value + other) % self.q_field
  153. return self._boxed(z_value)
  154. def _raw_sub(self, other):
  155. z_value = (self.value - other) % self.q_field
  156. return self._boxed(z_value)
  157. def __add__(self, other):
  158. if isinstance(other, PaillierFixedPointTensor):
  159. z_value = (self.value + other)
  160. return PaillierFixedPointTensor(z_value)
  161. elif isinstance(other, FixedPointTensor):
  162. return self._raw_add(other.value)
  163. z_value = (self.value + other) % self.q_field
  164. return self._boxed(z_value)
  165. def __radd__(self, other):
  166. return self.__add__(other)
  167. def __sub__(self, other):
  168. if isinstance(other, PaillierFixedPointTensor):
  169. z_value = (self.value - other)
  170. return PaillierFixedPointTensor(z_value)
  171. elif isinstance(other, FixedPointTensor):
  172. return self._raw_sub(other.value)
  173. z_value = (self.value - other) % self.q_field
  174. return self._boxed(z_value)
  175. def __rsub__(self, other):
  176. if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
  177. return other - self
  178. z_value = (other - self.value) % self.q_field
  179. return self._boxed(z_value)
  180. def __mul__(self, other):
  181. if isinstance(other, PaillierFixedPointTensor):
  182. z_value = self.value * other.value
  183. return PaillierFixedPointTensor(z_value)
  184. if isinstance(other, FixedPointTensor):
  185. other = other.value
  186. z_value = self.value * other
  187. z_value = z_value % self.q_field
  188. z_value = self.endec.truncate(z_value, self.get_spdz().party_idx)
  189. return self._boxed(z_value)
  190. def __rmul__(self, other):
  191. return self.__mul__(other)
  192. def __matmul__(self, other):
  193. return self.einsum(other, "ij,jk->ik")
  194. class PaillierFixedPointTensor(TensorBase):
  195. __array_ufunc__ = None
  196. def __init__(self, value, tensor_name: str = None, cipher=None):
  197. super().__init__(q_field=None, tensor_name=tensor_name)
  198. self.value = value
  199. self.cipher = cipher
  200. def dot(self, other, target_name=None):
  201. def _vec_dot(x, y):
  202. ret = np.dot(x, y)
  203. if not isinstance(ret, np.ndarray):
  204. ret = np.array([ret])
  205. return ret
  206. if isinstance(other, (FixedPointTensor, fixedpoint_table.FixedPointTensor)):
  207. other = other.value
  208. if isinstance(other, np.ndarray):
  209. ret = _vec_dot(self.value, other)
  210. return self._boxed(ret, target_name)
  211. elif is_table(other):
  212. f = functools.partial(_vec_dot,
  213. self.value)
  214. ret = other.mapValues(f)
  215. return fixedpoint_table.PaillierFixedPointTensor(value=ret,
  216. tensor_name=target_name,
  217. cipher=self.cipher)
  218. else:
  219. raise ValueError(f"type={type(other)}")
  220. def broadcast_reconstruct_share(self, tensor_name=None):
  221. from federatedml.secureprotol.spdz import SPDZ
  222. spdz = SPDZ.get_instance()
  223. share_val = self.value.copy()
  224. name = tensor_name or self.tensor_name
  225. if name is None:
  226. raise ValueError("name not specified")
  227. # remote share to other parties
  228. spdz.communicator.broadcast_rescontruct_share(share_val, name)
  229. return share_val
  230. def __str__(self):
  231. return f"tensor_name={self.tensor_name}, value={self.value}"
  232. def __repr__(self):
  233. return self.__str__()
  234. def _raw_add(self, other):
  235. z_value = (self.value + other)
  236. return self._boxed(z_value)
  237. def _raw_sub(self, other):
  238. z_value = (self.value - other)
  239. return self._boxed(z_value)
  240. def __add__(self, other):
  241. if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
  242. return self._raw_add(other.value)
  243. else:
  244. return self._raw_add(other)
  245. def __radd__(self, other):
  246. return self.__add__(other)
  247. def __sub__(self, other):
  248. if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
  249. return self._raw_sub(other.value)
  250. else:
  251. return self._raw_sub(other)
  252. def __rsub__(self, other):
  253. if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
  254. z_value = other.value - self.value
  255. else:
  256. z_value = other - self.value
  257. return self._boxed(z_value)
  258. def __mul__(self, other):
  259. if isinstance(other, PaillierFixedPointTensor):
  260. raise NotImplementedError("__mul__ not support PaillierFixedPointTensor")
  261. elif isinstance(other, FixedPointTensor):
  262. return self._boxed(self.value * other.value)
  263. else:
  264. return self._boxed(self.value * other)
  265. def __rmul__(self, other):
  266. self.__mul__(other)
  267. def _boxed(self, value, tensor_name=None):
  268. return PaillierFixedPointTensor(value=value,
  269. tensor_name=tensor_name,
  270. cipher=self.cipher)
  271. @classmethod
  272. def from_source(cls, tensor_name, source, **kwargs):
  273. spdz = cls.get_spdz()
  274. q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
  275. if 'encoder' in kwargs:
  276. encoder = kwargs['encoder']
  277. else:
  278. base = kwargs['base'] if 'base' in kwargs else 10
  279. frac = kwargs['frac'] if 'frac' in kwargs else 4
  280. encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
  281. if isinstance(source, np.ndarray):
  282. _pre = urand_tensor(q_field, source)
  283. share = _pre
  284. spdz.communicator.remote_share(share=source - encoder.decode(_pre),
  285. tensor_name=tensor_name,
  286. party=spdz.other_parties[-1])
  287. return FixedPointTensor(value=share,
  288. q_field=q_field,
  289. endec=encoder,
  290. tensor_name=tensor_name)
  291. elif isinstance(source, Party):
  292. share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
  293. is_cipher_source = kwargs['is_cipher_source'] if 'is_cipher_source' in kwargs else True
  294. if is_cipher_source:
  295. cipher = kwargs['cipher']
  296. share = cipher.recursive_decrypt(share)
  297. share = encoder.encode(share)
  298. return FixedPointTensor(value=share,
  299. q_field=q_field,
  300. endec=encoder,
  301. tensor_name=tensor_name)
  302. else:
  303. raise ValueError(f"type={type(source)}")