fixedpoint_table.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import operator
  17. from collections import Iterable
  18. import functools
  19. import numpy as np
  20. from fate_arch.common import Party
  21. from fate_arch.session import is_table
  22. from federatedml.secureprotol.spdz.beaver_triples import beaver_triplets
  23. from federatedml.secureprotol.spdz.tensor import fixedpoint_numpy
  24. from federatedml.secureprotol.spdz.tensor.base import TensorBase
  25. from federatedml.secureprotol.spdz.utils import NamingService
  26. from federatedml.secureprotol.spdz.utils import urand_tensor
  27. # from federatedml.secureprotol.spdz.tensor.fixedpoint_endec import FixedPointEndec
  28. from federatedml.secureprotol.fixedpoint import FixedPointEndec
  29. def _table_binary_op(x, y, op):
  30. return x.join(y, lambda a, b: op(a, b))
  31. def _table_binary_mod_op(x, y, q_field, op):
  32. return x.join(y, lambda a, b: op(a, b) % q_field)
  33. def _table_scalar_op(x, d, op):
  34. return x.mapValues(lambda a: op(a, d))
  35. def _table_scalar_mod_op(x, d, q_field, op):
  36. return x.mapValues(lambda a: op(a, d) % q_field)
  37. def _table_dot_mod_func(it, q_field):
  38. ret = None
  39. for _, (x, y) in it:
  40. if ret is None:
  41. ret = np.tensordot(x, y, [[], []]) % q_field
  42. else:
  43. ret = (ret + np.tensordot(x, y, [[], []])) % q_field
  44. return ret
  45. def _table_dot_func(it):
  46. ret = None
  47. for _, (x, y) in it:
  48. if ret is None:
  49. ret = np.tensordot(x, y, [[], []])
  50. else:
  51. ret += np.tensordot(x, y, [[], []])
  52. return ret
  53. def table_dot(a_table, b_table):
  54. return a_table.join(b_table, lambda x, y: [x, y]) \
  55. .applyPartitions(lambda it: _table_dot_func(it)) \
  56. .reduce(lambda x, y: x + y)
  57. def table_dot_mod(a_table, b_table, q_field):
  58. return a_table.join(b_table, lambda x, y: [x, y]) \
  59. .applyPartitions(lambda it: _table_dot_mod_func(it, q_field)) \
  60. .reduce(lambda x, y: x if y is None else y if x is None else x + y)
  61. class FixedPointTensor(TensorBase):
  62. """
  63. a table based tensor
  64. """
  65. __array_ufunc__ = None
  66. def __init__(self, value, q_field, endec, tensor_name: str = None):
  67. super().__init__(q_field, tensor_name)
  68. self.value = value
  69. self.endec = endec
  70. self.tensor_name = NamingService.get_instance().next() if tensor_name is None else tensor_name
  71. def dot(self, other: 'FixedPointTensor', target_name=None):
  72. spdz = self.get_spdz()
  73. if target_name is None:
  74. target_name = NamingService.get_instance().next()
  75. a, b, c = beaver_triplets(a_tensor=self.value, b_tensor=other.value, dot=table_dot,
  76. q_field=self.q_field, he_key_pair=(spdz.public_key, spdz.private_key),
  77. communicator=spdz.communicator, name=target_name)
  78. x_add_a = (self + a).rescontruct(f"{target_name}_confuse_x")
  79. y_add_b = (other + b).rescontruct(f"{target_name}_confuse_y")
  80. cross = c - table_dot_mod(a, y_add_b, self.q_field) - table_dot_mod(x_add_a, b, self.q_field)
  81. if spdz.party_idx == 0:
  82. cross += table_dot_mod(x_add_a, y_add_b, self.q_field)
  83. cross = cross % self.q_field
  84. cross = self.endec.truncate(cross, self.get_spdz().party_idx)
  85. share = fixedpoint_numpy.FixedPointTensor(cross, self.q_field, self.endec, target_name)
  86. return share
  87. def dot_local(self, other, target_name=None):
  88. def _vec_dot(x, y, party_idx, q_field, endec):
  89. ret = np.dot(x, y) % q_field
  90. ret = endec.truncate(ret, party_idx)
  91. if not isinstance(ret, np.ndarray):
  92. ret = np.array([ret])
  93. return ret
  94. if isinstance(other, FixedPointTensor) or isinstance(other, fixedpoint_numpy.FixedPointTensor):
  95. other = other.value
  96. if isinstance(other, np.ndarray):
  97. party_idx = self.get_spdz().party_idx
  98. f = functools.partial(_vec_dot, y=other,
  99. party_idx=party_idx,
  100. q_field=self.q_field,
  101. endec=self.endec)
  102. ret = self.value.mapValues(f)
  103. return self._boxed(ret, target_name)
  104. elif is_table(other):
  105. ret = table_dot_mod(self.value, other, self.q_field).reshape((1, -1))[0]
  106. ret = self.endec.truncate(ret, self.get_spdz().party_idx)
  107. return fixedpoint_numpy.FixedPointTensor(ret,
  108. self.q_field,
  109. self.endec,
  110. target_name)
  111. else:
  112. raise ValueError(f"type={type(other)}")
  113. def reduce(self, func, **kwargs):
  114. ret = self.value.reduce(func)
  115. return fixedpoint_numpy.FixedPointTensor(ret,
  116. self.q_field,
  117. self.endec
  118. )
  119. @property
  120. def shape(self):
  121. return self.value.count(), len(self.value.first()[1])
  122. @classmethod
  123. def from_source(cls, tensor_name, source, **kwargs):
  124. spdz = cls.get_spdz()
  125. q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
  126. if 'encoder' in kwargs:
  127. encoder = kwargs['encoder']
  128. else:
  129. base = kwargs['base'] if 'base' in kwargs else 10
  130. frac = kwargs['frac'] if 'frac' in kwargs else 4
  131. encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
  132. if is_table(source):
  133. source = encoder.encode(source)
  134. _pre = urand_tensor(q_field, source, use_mix=spdz.use_mix_rand)
  135. spdz.communicator.remote_share(share=_pre, tensor_name=tensor_name, party=spdz.other_parties[0])
  136. for _party in spdz.other_parties[1:]:
  137. r = urand_tensor(q_field, source, use_mix=spdz.use_mix_rand)
  138. spdz.communicator.remote_share(share=_table_binary_mod_op(r, _pre, q_field, operator.sub),
  139. tensor_name=tensor_name, party=_party)
  140. _pre = r
  141. share = _table_binary_mod_op(source, _pre, q_field, operator.sub)
  142. elif isinstance(source, Party):
  143. share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
  144. else:
  145. raise ValueError(f"type={type(source)}")
  146. return FixedPointTensor(share, q_field, encoder, tensor_name)
  147. def get(self, tensor_name=None, broadcast=True):
  148. return self.endec.decode(self.rescontruct(tensor_name, broadcast))
  149. def rescontruct(self, tensor_name=None, broadcast=True):
  150. from federatedml.secureprotol.spdz import SPDZ
  151. spdz = SPDZ.get_instance()
  152. share_val = self.value.copy()
  153. name = tensor_name or self.tensor_name
  154. if name is None:
  155. raise ValueError("name not specified")
  156. # remote share to other parties
  157. if broadcast:
  158. spdz.communicator.broadcast_rescontruct_share(share_val, name)
  159. # get shares from other parties
  160. for other_share in spdz.communicator.get_rescontruct_shares(name):
  161. share_val = _table_binary_mod_op(share_val, other_share, self.q_field, operator.add)
  162. return share_val
  163. def broadcast_reconstruct_share(self, tensor_name=None):
  164. from federatedml.secureprotol.spdz import SPDZ
  165. spdz = SPDZ.get_instance()
  166. share_val = self.value.copy()
  167. name = tensor_name or self.tensor_name
  168. if name is None:
  169. raise ValueError("name not specified")
  170. # remote share to other parties
  171. spdz.communicator.broadcast_rescontruct_share(share_val, name)
  172. return share_val
  173. def __str__(self):
  174. return f"tensor_name={self.tensor_name}, value={self.value}"
  175. def __repr__(self):
  176. return self.__str__()
  177. def as_name(self, tensor_name):
  178. return self._boxed(value=self.value, tensor_name=tensor_name)
  179. def __add__(self, other):
  180. if isinstance(other, PaillierFixedPointTensor):
  181. z_value = _table_binary_op(self.value, other.value, operator.add)
  182. return PaillierFixedPointTensor(z_value)
  183. elif isinstance(other, FixedPointTensor):
  184. z_value = _table_binary_mod_op(self.value, other.value, self.q_field, operator.add)
  185. elif is_table(other):
  186. z_value = _table_binary_mod_op(self.value, other, self.q_field, operator.add)
  187. else:
  188. z_value = _table_scalar_mod_op(self.value, other, self.q_field, operator.add)
  189. return self._boxed(z_value)
  190. def __radd__(self, other):
  191. return self.__add__(other)
  192. def __sub__(self, other):
  193. if isinstance(other, PaillierFixedPointTensor):
  194. z_value = _table_binary_op(self.value, other.value, operator.sub)
  195. return PaillierFixedPointTensor(z_value)
  196. elif isinstance(other, FixedPointTensor):
  197. z_value = _table_binary_mod_op(self.value, other.value, self.q_field, operator.sub)
  198. elif is_table(other):
  199. z_value = _table_binary_mod_op(self.value, other, self.q_field, operator.sub)
  200. else:
  201. z_value = _table_scalar_mod_op(self.value, other, self.q_field, operator.sub)
  202. return self._boxed(z_value)
  203. def __rsub__(self, other):
  204. if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
  205. return other - self
  206. elif is_table(other):
  207. z_value = _table_binary_mod_op(other, self.value, self.q_field, operator.sub)
  208. else:
  209. z_value = _table_scalar_mod_op(self.value, other, self.q_field, -1 * operator.sub)
  210. return self._boxed(z_value)
  211. def __mul__(self, other):
  212. if isinstance(other, FixedPointTensor):
  213. z_value = _table_binary_mod_op(self.value, other.value, self.q_field, operator.mul)
  214. elif isinstance(other, PaillierFixedPointTensor):
  215. z_value = _table_binary_op(self.value, other.value, operator.mul)
  216. else:
  217. z_value = _table_scalar_mod_op(self.value, other, self.q_field, operator.mul)
  218. z_value = self.endec.truncate(z_value, self.get_spdz().party_idx)
  219. return self._boxed(z_value)
  220. def __rmul__(self, other):
  221. return self.__mul__(other)
  222. def __mod__(self, other):
  223. if not isinstance(other, (int, np.integer)):
  224. raise NotImplementedError("__mod__ support integer only")
  225. return self._boxed(_table_scalar_op(self.value, other, operator.mod))
  226. def _boxed(self, value, tensor_name=None):
  227. return FixedPointTensor(value=value, q_field=self.q_field, endec=self.endec, tensor_name=tensor_name)
  228. class PaillierFixedPointTensor(TensorBase):
  229. __array_ufunc__ = None
  230. def __init__(self, value, tensor_name: str = None, cipher=None):
  231. super().__init__(q_field=None, tensor_name=tensor_name)
  232. self.value = value
  233. self.cipher = cipher
  234. def dot(self, other, target_name=None):
  235. def _vec_dot(x, y):
  236. ret = np.dot(x, y)
  237. if not isinstance(ret, np.ndarray):
  238. ret = np.array([ret])
  239. return ret
  240. if isinstance(other, (FixedPointTensor, fixedpoint_numpy.FixedPointTensor)):
  241. other = other.value
  242. if isinstance(other, np.ndarray):
  243. ret = self.value.mapValues(lambda x: _vec_dot(x, other))
  244. return self._boxed(ret, target_name)
  245. elif is_table(other):
  246. ret = table_dot(self.value, other).reshape((1, -1))[0]
  247. return fixedpoint_numpy.PaillierFixedPointTensor(ret, target_name)
  248. else:
  249. raise ValueError(f"type={type(other)}")
  250. def reduce(self, func, **kwargs):
  251. ret = self.value.reduce(func)
  252. return fixedpoint_numpy.PaillierFixedPointTensor(ret)
  253. def __str__(self):
  254. return f"tensor_name={self.tensor_name}, value={self.value}"
  255. def __repr__(self):
  256. return self.__str__()
  257. def __add__(self, other):
  258. if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
  259. return self._boxed(_table_binary_op(self.value, other.value, operator.add))
  260. elif is_table(other):
  261. return self._boxed(_table_binary_op(self.value, other, operator.add))
  262. else:
  263. return self._boxed(_table_scalar_op(self.value, other, operator.add))
  264. def __radd__(self, other):
  265. return self.__add__(other)
  266. def __sub__(self, other):
  267. if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
  268. return self._boxed(_table_binary_op(self.value, other.value, operator.sub))
  269. elif is_table(other):
  270. return self._boxed(_table_binary_op(self.value, other, operator.sub))
  271. else:
  272. return self._boxed(_table_scalar_op(self.value, other, operator.sub))
  273. def __rsub__(self, other):
  274. if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
  275. return self._boxed(_table_binary_op(other.value, self.value, operator.sub))
  276. elif is_table(other):
  277. return self._boxed(_table_binary_op(other, self.value, operator.sub))
  278. else:
  279. return self._boxed(_table_scalar_op(self.value, other, -1 * operator.sub))
  280. def __mul__(self, other):
  281. if isinstance(other, FixedPointTensor):
  282. z_value = _table_binary_op(self.value, other.value, operator.mul)
  283. elif is_table(other):
  284. z_value = _table_binary_op(self.value, other, operator.mul)
  285. else:
  286. z_value = _table_scalar_op(self.value, other, operator.mul)
  287. return self._boxed(z_value)
  288. def __rmul__(self, other):
  289. return self.__mul__(other)
  290. def _boxed(self, value, tensor_name=None):
  291. return PaillierFixedPointTensor(value=value, tensor_name=tensor_name)
  292. @classmethod
  293. def from_source(cls, tensor_name, source, **kwargs):
  294. spdz = cls.get_spdz()
  295. q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
  296. if 'encoder' in kwargs:
  297. encoder = kwargs['encoder']
  298. else:
  299. base = kwargs['base'] if 'base' in kwargs else 10
  300. frac = kwargs['frac'] if 'frac' in kwargs else 4
  301. encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
  302. if is_table(source):
  303. _pre = urand_tensor(q_field, source, use_mix=spdz.use_mix_rand)
  304. share = _pre
  305. spdz.communicator.remote_share(share=_table_binary_op(source, encoder.decode(_pre), operator.sub),
  306. tensor_name=tensor_name, party=spdz.other_parties[-1])
  307. return FixedPointTensor(value=share,
  308. q_field=q_field,
  309. endec=encoder,
  310. tensor_name=tensor_name)
  311. elif isinstance(source, Party):
  312. share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
  313. is_cipher_source = kwargs['is_cipher_source'] if 'is_cipher_source' in kwargs else True
  314. if is_cipher_source:
  315. cipher = kwargs.get("cipher")
  316. if cipher is None:
  317. raise ValueError("Cipher is not provided")
  318. share = cipher.distribute_decrypt(share)
  319. share = encoder.encode(share)
  320. return FixedPointTensor(value=share,
  321. q_field=q_field,
  322. endec=encoder,
  323. tensor_name=tensor_name)
  324. else:
  325. raise ValueError(f"type={type(source)}")