# fixedpoint.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import functools
  17. import math
  18. import sys
  19. import numpy as np
  20. class FixedPointNumber(object):
  21. """Represents a float or int fixedpoint encoding;.
  22. """
  23. BASE = 16
  24. LOG2_BASE = math.log(BASE, 2)
  25. FLOAT_MANTISSA_BITS = sys.float_info.mant_dig
  26. Q = 293973345475167247070445277780365744413 ** 2
  27. def __init__(self, encoding, exponent, n=None, max_int=None):
  28. if n is None:
  29. self.n = FixedPointNumber.Q
  30. self.max_int = self.n // 2
  31. else:
  32. self.n = n
  33. if max_int is None:
  34. self.max_int = self.n // 2
  35. else:
  36. self.max_int = max_int
  37. self.encoding = encoding
  38. self.exponent = exponent
  39. @classmethod
  40. def calculate_exponent_from_precision(cls, precision):
  41. exponent = math.floor(math.log(precision, cls.BASE))
  42. return exponent
  43. @classmethod
  44. def encode(cls, scalar, n=None, max_int=None, precision=None, max_exponent=None):
  45. """return an encoding of an int or float.
  46. """
  47. # Calculate the maximum exponent for desired precision
  48. exponent = None
  49. # Too low value preprocess;
  50. # avoid "OverflowError: int too large to convert to float"
  51. if np.abs(scalar) < 1e-200:
  52. scalar = 0
  53. if n is None:
  54. n = cls.Q
  55. max_int = n // 2
  56. if precision is None:
  57. if isinstance(scalar, int) or isinstance(scalar, np.int16) or \
  58. isinstance(scalar, np.int32) or isinstance(scalar, np.int64):
  59. exponent = 0
  60. elif isinstance(scalar, float) or isinstance(scalar, np.float16) \
  61. or isinstance(scalar, np.float32) or isinstance(scalar, np.float64):
  62. flt_exponent = math.frexp(scalar)[1]
  63. lsb_exponent = cls.FLOAT_MANTISSA_BITS - flt_exponent
  64. exponent = math.floor(lsb_exponent / cls.LOG2_BASE)
  65. else:
  66. raise TypeError("Don't know the precision of type %s."
  67. % type(scalar))
  68. else:
  69. exponent = cls.calculate_exponent_from_precision(precision)
  70. if max_exponent is not None:
  71. exponent = max(max_exponent, exponent)
  72. int_fixpoint = int(round(scalar * pow(cls.BASE, exponent)))
  73. if abs(int_fixpoint) > max_int:
  74. raise ValueError(f"Integer needs to be within +/- {max_int},but got {int_fixpoint},"
  75. f"basic info, scalar={scalar}, base={cls.BASE}, exponent={exponent}"
  76. )
  77. return cls(int_fixpoint % n, exponent, n, max_int)
  78. def decode(self):
  79. """return decode plaintext.
  80. """
  81. if self.encoding >= self.n:
  82. # Should be mod n
  83. raise ValueError('Attempted to decode corrupted number')
  84. elif self.encoding <= self.max_int:
  85. # Positive
  86. mantissa = self.encoding
  87. elif self.encoding >= self.n - self.max_int:
  88. # Negative
  89. mantissa = self.encoding - self.n
  90. else:
  91. raise OverflowError(f'Overflow detected in decode number, encoding: {self.encoding},'
  92. f'{self.exponent}'
  93. f' {self.n}')
  94. return mantissa * pow(self.BASE, -self.exponent)
  95. def increase_exponent_to(self, new_exponent):
  96. """return FixedPointNumber: new encoding with same value but having great exponent.
  97. """
  98. if new_exponent < self.exponent:
  99. raise ValueError('New exponent %i should be greater than'
  100. 'old exponent %i' % (new_exponent, self.exponent))
  101. factor = pow(self.BASE, new_exponent - self.exponent)
  102. new_encoding = self.encoding * factor % self.n
  103. return FixedPointNumber(new_encoding, new_exponent, self.n, self.max_int)
  104. def __align_exponent(self, x, y):
  105. """return x,y with same exponent
  106. """
  107. if x.exponent < y.exponent:
  108. x = x.increase_exponent_to(y.exponent)
  109. elif x.exponent > y.exponent:
  110. y = y.increase_exponent_to(x.exponent)
  111. return x, y
  112. def __truncate(self, a):
  113. scalar = a.decode()
  114. return FixedPointNumber.encode(scalar, n=self.n, max_int=self.max_int)
  115. def __add__(self, other):
  116. if isinstance(other, FixedPointNumber):
  117. return self.__add_fixedpointnumber(other)
  118. elif type(other).__name__ == "PaillierEncryptedNumber":
  119. return other + self.decode()
  120. else:
  121. return self.__add_scalar(other)
  122. def __radd__(self, other):
  123. return self.__add__(other)
  124. def __sub__(self, other):
  125. if isinstance(other, FixedPointNumber):
  126. return self.__sub_fixedpointnumber(other)
  127. elif type(other).__name__ == "PaillierEncryptedNumber":
  128. return (other - self.decode()) * -1
  129. else:
  130. return self.__sub_scalar(other)
  131. def __rsub__(self, other):
  132. if type(other).__name__ == "PaillierEncryptedNumber":
  133. return other - self.decode()
  134. x = self.__sub__(other)
  135. x = -1 * x.decode()
  136. return self.encode(x, n=self.n, max_int=self.max_int)
  137. def __rmul__(self, other):
  138. return self.__mul__(other)
  139. def __mul__(self, other):
  140. if isinstance(other, FixedPointNumber):
  141. return self.__mul_fixedpointnumber(other)
  142. elif type(other).__name__ == "PaillierEncryptedNumber":
  143. return other * self.decode()
  144. else:
  145. return self.__mul_scalar(other)
  146. def __truediv__(self, other):
  147. if isinstance(other, FixedPointNumber):
  148. scalar = other.decode()
  149. else:
  150. scalar = other
  151. return self.__mul__(1 / scalar)
  152. def __rtruediv__(self, other):
  153. res = 1.0 / self.__truediv__(other).decode()
  154. return FixedPointNumber.encode(res, n=self.n, max_int=self.max_int)
  155. def __lt__(self, other):
  156. x = self.decode()
  157. if isinstance(other, FixedPointNumber):
  158. y = other.decode()
  159. else:
  160. y = other
  161. if x < y:
  162. return True
  163. else:
  164. return False
  165. def __gt__(self, other):
  166. x = self.decode()
  167. if isinstance(other, FixedPointNumber):
  168. y = other.decode()
  169. else:
  170. y = other
  171. if x > y:
  172. return True
  173. else:
  174. return False
  175. def __le__(self, other):
  176. x = self.decode()
  177. if isinstance(other, FixedPointNumber):
  178. y = other.decode()
  179. else:
  180. y = other
  181. if x <= y:
  182. return True
  183. else:
  184. return False
  185. def __ge__(self, other):
  186. x = self.decode()
  187. if isinstance(other, FixedPointNumber):
  188. y = other.decode()
  189. else:
  190. y = other
  191. if x >= y:
  192. return True
  193. else:
  194. return False
  195. def __eq__(self, other):
  196. x = self.decode()
  197. if isinstance(other, FixedPointNumber):
  198. y = other.decode()
  199. else:
  200. y = other
  201. if x == y:
  202. return True
  203. else:
  204. return False
  205. def __ne__(self, other):
  206. x = self.decode()
  207. if isinstance(other, FixedPointNumber):
  208. y = other.decode()
  209. else:
  210. y = other
  211. if x != y:
  212. return True
  213. else:
  214. return False
  215. def __add_fixedpointnumber(self, other):
  216. if self.n != other.n:
  217. other = self.encode(other.decode(), n=self.n, max_int=self.max_int)
  218. x, y = self.__align_exponent(self, other)
  219. encoding = (x.encoding + y.encoding) % self.n
  220. return FixedPointNumber(encoding, x.exponent, n=self.n, max_int=self.max_int)
  221. def __add_scalar(self, scalar):
  222. encoded = self.encode(scalar, n=self.n, max_int=self.max_int)
  223. return self.__add_fixedpointnumber(encoded)
  224. def __sub_fixedpointnumber(self, other):
  225. if self.n != other.n:
  226. other = self.encode(other.decode(), n=self.n, max_int=self.max_int)
  227. x, y = self.__align_exponent(self, other)
  228. encoding = (x.encoding - y.encoding) % self.n
  229. return FixedPointNumber(encoding, x.exponent, n=self.n, max_int=self.max_int)
  230. def __sub_scalar(self, scalar):
  231. scalar = -1 * scalar
  232. return self.__add_scalar(scalar)
  233. def __mul_fixedpointnumber(self, other):
  234. return self.__mul_scalar(other.decode())
  235. def __mul_scalar(self, scalar):
  236. val = self.decode()
  237. z = val * scalar
  238. z_encode = FixedPointNumber.encode(z, n=self.n, max_int=self.max_int)
  239. return z_encode
  240. def __abs__(self):
  241. if self.encoding <= self.max_int:
  242. # Positive
  243. return self
  244. elif self.encoding >= self.n - self.max_int:
  245. # Negative
  246. return self * -1
  247. def __mod__(self, other):
  248. return FixedPointNumber(self.encoding % other, self.exponent, n=self.n, max_int=self.max_int)
  249. class FixedPointEndec(object):
  250. def __init__(self, n=None, max_int=None, precision=None, *args, **kwargs):
  251. if n is None:
  252. self.n = FixedPointNumber.Q
  253. self.max_int = self.n // 2
  254. else:
  255. self.n = n
  256. if max_int is None:
  257. self.max_int = self.n // 2
  258. else:
  259. self.max_int = max_int
  260. self.precision = precision
  261. @classmethod
  262. def _transform_op(cls, tensor, op):
  263. from fate_arch.session import is_table
  264. def _transform(x):
  265. arr = np.zeros(shape=x.shape, dtype=object)
  266. view = arr.view().reshape(-1)
  267. x_array = x.view().reshape(-1)
  268. for i in range(arr.size):
  269. view[i] = op(x_array[i])
  270. return arr
  271. if isinstance(tensor, (int, np.int16, np.int32, np.int64,
  272. float, np.float16, np.float32, np.float64,
  273. FixedPointNumber)):
  274. return op(tensor)
  275. if isinstance(tensor, np.ndarray):
  276. z = _transform(tensor)
  277. return z
  278. elif is_table(tensor):
  279. f = functools.partial(_transform)
  280. return tensor.mapValues(f)
  281. else:
  282. raise ValueError(f"unsupported type: {type(tensor)}")
  283. def _encode(self, scalar):
  284. return FixedPointNumber.encode(scalar,
  285. n=self.n,
  286. max_int=self.max_int,
  287. precision=self.precision)
  288. def _decode(self, number):
  289. return number.decode()
  290. def _truncate(self, number):
  291. scalar = number.decode()
  292. return FixedPointNumber.encode(scalar, n=self.n, max_int=self.max_int)
  293. def encode(self, float_tensor):
  294. return self._transform_op(float_tensor, op=self._encode)
  295. def decode(self, integer_tensor):
  296. return self._transform_op(integer_tensor, op=self._decode)
  297. def truncate(self, integer_tensor, *args, **kwargs):
  298. return self._transform_op(integer_tensor, op=self._truncate)