paillier_tensor.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import numpy as np
  19. from federatedml.util import LOGGER
  20. from fate_arch.session import computing_session
  21. from fate_arch.abc import CTableABC
  22. class PaillierTensor(object):
  23. def __init__(self, obj, partitions=1):
  24. if obj is None:
  25. raise ValueError("Cannot convert None to Paillier tensor")
  26. if isinstance(obj, (list, np.ndarray)):
  27. self._ori_data = obj
  28. self._partitions = partitions
  29. self._obj = computing_session.parallelize(obj,
  30. include_key=False,
  31. partition=partitions)
  32. elif isinstance(obj, CTableABC):
  33. self._ori_data = None
  34. self._partitions = obj.partitions
  35. self._obj = obj
  36. else:
  37. raise ValueError(f"Cannot convert obj to Paillier tensor, object type is {type(obj)}")
  38. LOGGER.debug("tensor's partition is {}".format(self._partitions))
  39. def __add__(self, other):
  40. if isinstance(other, PaillierTensor):
  41. return PaillierTensor(self._obj.join(other._obj, lambda v1, v2: v1 + v2))
  42. elif isinstance(other, CTableABC):
  43. return PaillierTensor(self._obj.join(other, lambda v1, v2: v1 + v2))
  44. elif isinstance(other, (np.ndarray, int, float)):
  45. return PaillierTensor(self._obj.mapValues(lambda v: v + other))
  46. else:
  47. raise ValueError(f"Unrecognized type {type(other)}, dose not support subtraction")
  48. def __radd__(self, other):
  49. return self.__add__(other)
  50. def __sub__(self, other):
  51. if isinstance(other, PaillierTensor):
  52. return PaillierTensor(self._obj.join(other._obj, lambda v1, v2: v1 - v2))
  53. elif isinstance(other, CTableABC):
  54. return PaillierTensor(self._obj.join(other, lambda v1, v2: v1 - v2))
  55. elif isinstance(other, (np.ndarray, int, float)):
  56. return PaillierTensor(self._obj.mapValues(lambda v: v - other))
  57. else:
  58. raise ValueError(f"Unrecognized type {type(other)}, dose not support subtraction")
  59. def __rsub__(self, other):
  60. return self.__sub__(other)
  61. def __mul__(self, other):
  62. if isinstance(other, (int, float)):
  63. return PaillierTensor(self._obj.mapValues(lambda val: val * other))
  64. elif isinstance(other, np.ndarray):
  65. return PaillierTensor(self._obj.mapValues(lambda val: np.matmul(val, other)))
  66. elif isinstance(other, CTableABC):
  67. other = PaillierTensor(other)
  68. return self.__mul__(other)
  69. elif isinstance(other, PaillierTensor):
  70. ret = self.numpy() * other.numpy()
  71. return PaillierTensor(ret, partitions=max(self.partitions, other.partitions))
  72. def matmul(self, other):
  73. if isinstance(other, np.ndarray):
  74. if len(other.shape) != 2:
  75. raise ValueError("Only Support 2-D multiplication in matmul op, "
  76. "if you want to do 3-D, use fast_multiply_3d")
  77. return self.fast_matmul_2d(other)
  78. def multiply(self, other):
  79. if isinstance(other, np.ndarray):
  80. if other.shape != self.shape:
  81. raise ValueError(f"operands could not be broadcast together with shapes {self.shape} {other.shape}")
  82. rhs = PaillierTensor(other)
  83. return PaillierTensor(self.multiply(rhs))
  84. elif isinstance(other, CTableABC):
  85. other = PaillierTensor(other)
  86. return self.multiply(other)
  87. elif isinstance(other, PaillierTensor):
  88. return PaillierTensor(self._obj.join(other._obj, lambda v1, v2: v1 * v2))
  89. else:
  90. raise ValueError(f"Not support type in multiply op {type(other)}")
  91. @property
  92. def T(self):
  93. if self._ori_data is None:
  94. self._ori_data = self.numpy()
  95. new_data = self._ori_data.T
  96. return PaillierTensor(new_data, self.partitions)
  97. @property
  98. def partitions(self):
  99. return self._partitions
  100. def get_obj(self):
  101. return self._obj
  102. @property
  103. def shape(self):
  104. if self._ori_data is not None:
  105. return self._ori_data.shape
  106. else:
  107. first_dim = self._obj.count()
  108. if not first_dim:
  109. return (0, )
  110. second_dim = self._obj.first()[1].shape
  111. return tuple([first_dim] + list(second_dim))
  112. def mean(self, axis=-1):
  113. if axis == -1:
  114. size = 1
  115. for shape in self._ori_data.shape:
  116. size *= shape
  117. if not size:
  118. raise ValueError("shape of data is zero, it should be positive")
  119. return self._obj.mapValues(lambda val: np.sum(val)).reduce(lambda val1, val2: val1 + val2) / size
  120. else:
  121. ret_obj = self._obj.mapValues(lambda val: np.mean(val, axis - 1))
  122. return PaillierTensor(ret_obj)
  123. def reduce_sum(self):
  124. return self._obj.reduce(lambda t1, t2: t1 + t2)
  125. def map_ndarray_product(self, other):
  126. if isinstance(other, np.ndarray):
  127. return PaillierTensor(self._obj.mapValues(lambda val: val * other))
  128. else:
  129. raise ValueError('only support numpy array')
  130. def numpy(self):
  131. if self._ori_data is not None:
  132. return self._ori_data
  133. arr = [None for i in range(self._obj.count())]
  134. for k, v in self._obj.collect():
  135. arr[k] = v
  136. self._ori_data = np.array(arr, dtype=arr[0].dtype)
  137. return self._ori_data
  138. def encrypt(self, encrypt_tool):
  139. return PaillierTensor(encrypt_tool.distribute_encrypt(self._obj))
  140. def decrypt(self, decrypt_tool):
  141. return PaillierTensor(self._obj.mapValues(lambda val: decrypt_tool.recursive_decrypt(val)))
  142. def encode(self, encoder):
  143. return PaillierTensor(self._obj.mapValues(lambda val: encoder.encode(val)))
  144. def decode(self, decoder):
  145. return PaillierTensor(self._obj.mapValues(lambda val: decoder.decode(val)))
  146. @staticmethod
  147. def _vector_mul(kv_iters):
  148. ret_mat = None
  149. for k, v in kv_iters:
  150. tmp_mat = np.outer(v[0], v[1])
  151. if ret_mat is not None:
  152. ret_mat += tmp_mat
  153. else:
  154. ret_mat = tmp_mat
  155. return ret_mat
  156. def fast_matmul_2d(self, other):
  157. """
  158. Matrix multiplication between two matrix, please ensure that self's shape is (m, n) and other's shape is (m, k)
  159. Their result is a matrix of (n, k)
  160. """
  161. if isinstance(other, np.ndarray):
  162. mat_tensor = PaillierTensor(other, partitions=self.partitions)
  163. return self.fast_matmul_2d(mat_tensor)
  164. if isinstance(other, CTableABC):
  165. other = PaillierTensor(other)
  166. func = self._vector_mul
  167. ret_mat = self._obj.join(other.get_obj(), lambda vec1, vec2: (vec1, vec2)).applyPartitions(func).reduce(
  168. lambda mat1, mat2: mat1 + mat2)
  169. return ret_mat
  170. def matmul_3d(self, other, multiply='left'):
  171. assert multiply in ['left', 'right']
  172. if isinstance(other, PaillierTensor):
  173. mat = other
  174. elif isinstance(other, CTableABC):
  175. mat = PaillierTensor(other)
  176. elif isinstance(other, np.ndarray):
  177. mat = PaillierTensor(other, partitions=self.partitions)
  178. else:
  179. raise ValueError('only support numpy array and Paillier Tensor')
  180. if multiply == 'left':
  181. return PaillierTensor(self._obj.join(mat._obj, lambda val1, val2: np.tensordot(val1, val2, (1, 0))),
  182. partitions=self._partitions)
  183. if multiply == 'right':
  184. return PaillierTensor(mat._obj.join(self._obj, lambda val1, val2: np.tensordot(val1, val2, (1, 0))),
  185. partitions=self._partitions)
  186. def element_wise_product(self, other):
  187. if isinstance(other, np.ndarray):
  188. mat = PaillierTensor(other, partitions=self.partitions)
  189. elif isinstance(other, CTableABC):
  190. mat = PaillierTensor(other)
  191. else:
  192. mat = other
  193. return PaillierTensor(self._obj.join(mat._obj, lambda val1, val2: val1 * val2))
  194. def squeeze(self, axis):
  195. if axis == 0:
  196. return PaillierTensor(list(self._obj.collect())[0][1], partitions=self.partitions)
  197. else:
  198. return PaillierTensor(self._obj.mapValues(lambda val: np.squeeze(val, axis=axis - 1)))
  199. def select_columns(self, select_table):
  200. return PaillierTensor(self._obj.join(select_table, lambda v1, v2: v1[v2]))