"""Paillier encryption library for partially homomorphic encryption."""
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  17. import random
  18. from federatedml.secureprotol import gmpy_math
  19. from federatedml.secureprotol.fixedpoint import FixedPointNumber
  20. class PaillierKeypair(object):
  21. def __init__(self):
  22. pass
  23. @staticmethod
  24. def generate_keypair(n_length=1024):
  25. """return a new :class:`PaillierPublicKey` and :class:`PaillierPrivateKey`.
  26. """
  27. p = q = n = None
  28. n_len = 0
  29. while n_len != n_length:
  30. p = gmpy_math.getprimeover(n_length // 2)
  31. q = p
  32. while q == p:
  33. q = gmpy_math.getprimeover(n_length // 2)
  34. n = p * q
  35. n_len = n.bit_length()
  36. public_key = PaillierPublicKey(n)
  37. private_key = PaillierPrivateKey(public_key, p, q)
  38. return public_key, private_key
  39. class PaillierPublicKey(object):
  40. """Contains a public key and associated encryption methods.
  41. """
  42. def __init__(self, n):
  43. self.g = n + 1
  44. self.n = n
  45. self.nsquare = n * n
  46. self.max_int = n // 3 - 1
  47. def __repr__(self):
  48. hashcode = hex(hash(self))[2:]
  49. return "<PaillierPublicKey {}>".format(hashcode[:10])
  50. def __eq__(self, other):
  51. return self.n == other.n
  52. def __hash__(self):
  53. return hash(self.n)
  54. def apply_obfuscator(self, ciphertext, random_value=None):
  55. """
  56. """
  57. r = random_value or random.SystemRandom().randrange(1, self.n)
  58. obfuscator = gmpy_math.powmod(r, self.n, self.nsquare)
  59. return (ciphertext * obfuscator) % self.nsquare
  60. def raw_encrypt(self, plaintext, random_value=None):
  61. """
  62. """
  63. if not isinstance(plaintext, int):
  64. raise TypeError("plaintext should be int, but got: %s" %
  65. type(plaintext))
  66. if plaintext >= (self.n - self.max_int) and plaintext < self.n:
  67. # Very large plaintext, take a sneaky shortcut using inverses
  68. neg_plaintext = self.n - plaintext # = abs(plaintext - nsquare)
  69. neg_ciphertext = (self.n * neg_plaintext + 1) % self.nsquare
  70. ciphertext = gmpy_math.invert(neg_ciphertext, self.nsquare)
  71. else:
  72. ciphertext = (self.n * plaintext + 1) % self.nsquare
  73. ciphertext = self.apply_obfuscator(ciphertext, random_value)
  74. return ciphertext
  75. def encrypt(self, value, precision=None, random_value=None):
  76. """Encode and Paillier encrypt a real number value.
  77. """
  78. if isinstance(value, FixedPointNumber):
  79. value = value.decode()
  80. encoding = FixedPointNumber.encode(value, self.n, self.max_int, precision)
  81. obfuscator = random_value or 1
  82. ciphertext = self.raw_encrypt(encoding.encoding, random_value=obfuscator)
  83. encryptednumber = PaillierEncryptedNumber(self, ciphertext, encoding.exponent)
  84. if random_value is None:
  85. encryptednumber.apply_obfuscator()
  86. return encryptednumber
  87. class PaillierPrivateKey(object):
  88. """Contains a private key and associated decryption method.
  89. """
  90. def __init__(self, public_key, p, q):
  91. if not p * q == public_key.n:
  92. raise ValueError("given public key does not match the given p and q")
  93. if p == q:
  94. raise ValueError("p and q have to be different")
  95. self.public_key = public_key
  96. if q < p:
  97. self.p = q
  98. self.q = p
  99. else:
  100. self.p = p
  101. self.q = q
  102. self.psquare = self.p * self.p
  103. self.qsquare = self.q * self.q
  104. self.q_inverse = gmpy_math.invert(self.q, self.p)
  105. self.hp = self.h_func(self.p, self.psquare)
  106. self.hq = self.h_func(self.q, self.qsquare)
  107. def __eq__(self, other):
  108. return self.p == other.p and self.q == other.q
  109. def __hash__(self):
  110. return hash((self.p, self.q))
  111. def __repr__(self):
  112. hashcode = hex(hash(self))[2:]
  113. return "<PaillierPrivateKey {}>".format(hashcode[:10])
  114. def h_func(self, x, xsquare):
  115. """Computes the h-function as defined in Paillier's paper page.
  116. """
  117. return gmpy_math.invert(self.l_func(gmpy_math.powmod(self.public_key.g,
  118. x - 1, xsquare), x), x)
  119. def l_func(self, x, p):
  120. """computes the L function as defined in Paillier's paper.
  121. """
  122. return (x - 1) // p
  123. def crt(self, mp, mq):
  124. """the Chinese Remainder Theorem as needed for decryption.
  125. return the solution modulo n=pq.
  126. """
  127. u = (mp - mq) * self.q_inverse % self.p
  128. x = (mq + (u * self.q)) % self.public_key.n
  129. return x
  130. def raw_decrypt(self, ciphertext):
  131. """return raw plaintext.
  132. """
  133. if not isinstance(ciphertext, int):
  134. raise TypeError("ciphertext should be an int, not: %s" %
  135. type(ciphertext))
  136. mp = self.l_func(gmpy_math.powmod(ciphertext,
  137. self.p - 1, self.psquare),
  138. self.p) * self.hp % self.p
  139. mq = self.l_func(gmpy_math.powmod(ciphertext,
  140. self.q - 1, self.qsquare),
  141. self.q) * self.hq % self.q
  142. return self.crt(mp, mq)
  143. def decrypt(self, encrypted_number):
  144. """return the decrypted & decoded plaintext of encrypted_number.
  145. """
  146. if not isinstance(encrypted_number, PaillierEncryptedNumber):
  147. raise TypeError("encrypted_number should be an PaillierEncryptedNumber, \
  148. not: %s" % type(encrypted_number))
  149. if self.public_key != encrypted_number.public_key:
  150. raise ValueError("encrypted_number was encrypted against a different key!")
  151. encoded = self.raw_decrypt(encrypted_number.ciphertext(be_secure=False))
  152. encoded = FixedPointNumber(encoded,
  153. encrypted_number.exponent,
  154. self.public_key.n,
  155. self.public_key.max_int)
  156. decrypt_value = encoded.decode()
  157. return decrypt_value
  158. class PaillierEncryptedNumber(object):
  159. """Represents the Paillier encryption of a float or int.
  160. """
  161. def __init__(self, public_key, ciphertext, exponent=0):
  162. self.public_key = public_key
  163. self.__ciphertext = ciphertext
  164. self.exponent = exponent
  165. self.__is_obfuscator = False
  166. if not isinstance(self.__ciphertext, int):
  167. raise TypeError("ciphertext should be an int, not: %s" % type(self.__ciphertext))
  168. if not isinstance(self.public_key, PaillierPublicKey):
  169. raise TypeError("public_key should be a PaillierPublicKey, not: %s" % type(self.public_key))
  170. def ciphertext(self, be_secure=True):
  171. """return the ciphertext of the PaillierEncryptedNumber.
  172. """
  173. if be_secure and not self.__is_obfuscator:
  174. self.apply_obfuscator()
  175. return self.__ciphertext
  176. def apply_obfuscator(self):
  177. """ciphertext by multiplying by r ** n with random r
  178. """
  179. self.__ciphertext = self.public_key.apply_obfuscator(self.__ciphertext)
  180. self.__is_obfuscator = True
  181. def __add__(self, other):
  182. if isinstance(other, PaillierEncryptedNumber):
  183. return self.__add_encryptednumber(other)
  184. else:
  185. return self.__add_scalar(other)
  186. def __radd__(self, other):
  187. return self.__add__(other)
  188. def __sub__(self, other):
  189. return self + (other * -1)
  190. def __rsub__(self, other):
  191. return other + (self * -1)
  192. def __rmul__(self, scalar):
  193. return self.__mul__(scalar)
  194. def __truediv__(self, scalar):
  195. return self.__mul__(1 / scalar)
  196. def __mul__(self, scalar):
  197. """return Multiply by an scalar(such as int, float)
  198. """
  199. if isinstance(scalar, FixedPointNumber):
  200. scalar = scalar.decode()
  201. encode = FixedPointNumber.encode(scalar, self.public_key.n, self.public_key.max_int)
  202. plaintext = encode.encoding
  203. if plaintext < 0 or plaintext >= self.public_key.n:
  204. raise ValueError("Scalar out of bounds: %i" % plaintext)
  205. if plaintext >= self.public_key.n - self.public_key.max_int:
  206. # Very large plaintext, play a sneaky trick using inverses
  207. neg_c = gmpy_math.invert(self.ciphertext(False), self.public_key.nsquare)
  208. neg_scalar = self.public_key.n - plaintext
  209. ciphertext = gmpy_math.powmod(neg_c, neg_scalar, self.public_key.nsquare)
  210. else:
  211. ciphertext = gmpy_math.powmod(self.ciphertext(False), plaintext, self.public_key.nsquare)
  212. exponent = self.exponent + encode.exponent
  213. return PaillierEncryptedNumber(self.public_key, ciphertext, exponent)
  214. def increase_exponent_to(self, new_exponent):
  215. """return PaillierEncryptedNumber:
  216. new PaillierEncryptedNumber with same value but having great exponent.
  217. """
  218. if new_exponent < self.exponent:
  219. raise ValueError("New exponent %i should be great than old exponent %i" % (new_exponent, self.exponent))
  220. factor = pow(FixedPointNumber.BASE, new_exponent - self.exponent)
  221. new_encryptednumber = self.__mul__(factor)
  222. new_encryptednumber.exponent = new_exponent
  223. return new_encryptednumber
  224. def __align_exponent(self, x, y):
  225. """return x,y with same exponet
  226. """
  227. if x.exponent < y.exponent:
  228. x = x.increase_exponent_to(y.exponent)
  229. elif x.exponent > y.exponent:
  230. y = y.increase_exponent_to(x.exponent)
  231. return x, y
  232. def __add_scalar(self, scalar):
  233. """return PaillierEncryptedNumber: z = E(x) + y
  234. """
  235. if isinstance(scalar, FixedPointNumber):
  236. scalar = scalar.decode()
  237. encoded = FixedPointNumber.encode(scalar,
  238. self.public_key.n,
  239. self.public_key.max_int,
  240. max_exponent=self.exponent)
  241. return self.__add_fixpointnumber(encoded)
  242. def __add_fixpointnumber(self, encoded):
  243. """return PaillierEncryptedNumber: z = E(x) + FixedPointNumber(y)
  244. """
  245. if self.public_key.n != encoded.n:
  246. raise ValueError("Attempted to add numbers encoded against different public keys!")
  247. # their exponents must match, and align.
  248. x, y = self.__align_exponent(self, encoded)
  249. encrypted_scalar = x.public_key.raw_encrypt(y.encoding, 1)
  250. encryptednumber = self.__raw_add(x.ciphertext(False), encrypted_scalar, x.exponent)
  251. return encryptednumber
  252. def __add_encryptednumber(self, other):
  253. """return PaillierEncryptedNumber: z = E(x) + E(y)
  254. """
  255. if self.public_key != other.public_key:
  256. raise ValueError("add two numbers have different public key!")
  257. # their exponents must match, and align.
  258. x, y = self.__align_exponent(self, other)
  259. encryptednumber = self.__raw_add(x.ciphertext(False), y.ciphertext(False), x.exponent)
  260. return encryptednumber
  261. def __raw_add(self, e_x, e_y, exponent):
  262. """return the integer E(x + y) given ints E(x) and E(y).
  263. """
  264. ciphertext = gmpy_math.mpz(e_x) * gmpy_math.mpz(e_y) % self.public_key.nsquare
  265. return PaillierEncryptedNumber(self.public_key, int(ciphertext), exponent)