encrypt.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
import functools
import hashlib

try:
    # collections.abc has provided Iterable since Python 3.3; the plain
    # `collections.Iterable` alias was removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable

import numpy as np
from Cryptodome import Random
from Cryptodome.PublicKey import RSA

from federatedml.util import LOGGER
from federatedml.feature.instance import Instance
from federatedml.secureprotol import gmpy_math
from federatedml.secureprotol.fate_paillier import PaillierKeypair
from federatedml.secureprotol.fate_paillier import PaillierEncryptedNumber
from federatedml.secureprotol.random import RandomPads

try:
    from ipcl_python import PaillierKeypair as IpclPaillierKeypair
except ImportError:
    pass

_TORCH_VALID = False
try:
    import torch

    _TORCH_VALID = True
except ImportError:
    pass
  38. class Encrypt(object):
  39. def __init__(self):
  40. self.public_key = None
  41. self.privacy_key = None
  42. def generate_key(self, n_length=0):
  43. pass
  44. def set_public_key(self, public_key):
  45. pass
  46. def get_public_key(self):
  47. pass
  48. def set_privacy_key(self, privacy_key):
  49. pass
  50. def get_privacy_key(self):
  51. pass
  52. def encrypt(self, value):
  53. pass
  54. def decrypt(self, value):
  55. pass
  56. def raw_encrypt(self, value):
  57. pass
  58. def raw_decrypt(self, value):
  59. pass
  60. def encrypt_list(self, values):
  61. result = [self.encrypt(msg) for msg in values]
  62. return result
  63. def decrypt_list(self, values):
  64. result = [self.decrypt(msg) for msg in values]
  65. return result
  66. def distribute_decrypt(self, X):
  67. decrypt_table = X.mapValues(lambda x: self.recursive_decrypt(x))
  68. return decrypt_table
  69. def distribute_encrypt(self, X):
  70. encrypt_table = X.mapValues(lambda x: self.recursive_encrypt(x))
  71. return encrypt_table
  72. def distribute_raw_decrypt(self, X):
  73. return X.mapValues(lambda x: self.recursive_raw_decrypt(x))
  74. def distribute_raw_encrypt(self, X):
  75. return X.mapValues(lambda x: self.recursive_raw_encrypt(x))
  76. def _recursive_func(self, obj, func):
  77. if isinstance(obj, np.ndarray):
  78. if len(obj.shape) == 1:
  79. return np.reshape([func(val) for val in obj], obj.shape)
  80. else:
  81. return np.reshape(
  82. [self._recursive_func(o, func) for o in obj], obj.shape
  83. )
  84. elif isinstance(obj, Iterable):
  85. return type(obj)(
  86. self._recursive_func(o, func) if isinstance(o, Iterable) else func(o)
  87. for o in obj
  88. )
  89. else:
  90. return func(obj)
  91. def recursive_encrypt(self, X):
  92. return self._recursive_func(X, self.encrypt)
  93. def recursive_decrypt(self, X):
  94. return self._recursive_func(X, self.decrypt)
  95. def recursive_raw_encrypt(self, X):
  96. return self._recursive_func(X, self.raw_encrypt)
  97. def recursive_raw_decrypt(self, X):
  98. return self._recursive_func(X, self.raw_decrypt)
  99. class RsaEncrypt(Encrypt):
  100. def __init__(self):
  101. super(RsaEncrypt, self).__init__()
  102. self.e = None
  103. self.d = None
  104. self.n = None
  105. self.p = None
  106. self.q = None
  107. def generate_key(self, rsa_bit=1024):
  108. random_generator = Random.new().read
  109. rsa = RSA.generate(rsa_bit, random_generator)
  110. self.e = rsa.e
  111. self.d = rsa.d
  112. self.n = rsa.n
  113. self.p = rsa.p
  114. self.q = rsa.q
  115. def get_key_pair(self):
  116. return self.e, self.d, self.n, self.p, self.q
  117. def set_public_key(self, public_key):
  118. self.e = public_key["e"]
  119. self.n = public_key["n"]
  120. def get_public_key(self):
  121. return self.e, self.n
  122. def set_privacy_key(self, privacy_key):
  123. self.d = privacy_key["d"]
  124. self.n = privacy_key["n"]
  125. def get_privacy_key(self):
  126. return self.d, self.n
  127. def encrypt(self, value):
  128. if self.e is not None and self.n is not None and self.p is not None and self.q is not None:
  129. cp, cq = gmpy_math.crt_coefficient(self.p, self.q)
  130. return gmpy_math.powmod_crt(value, self.e, self.n, self.p, self.q, cp, cq)
  131. if self.e is not None and self.n is not None:
  132. return gmpy_math.powmod(value, self.e, self.n)
  133. else:
  134. return None
  135. def decrypt(self, value):
  136. if self.d is not None and self.n is not None:
  137. return gmpy_math.powmod(value, self.d, self.n)
  138. else:
  139. return None
  140. class PaillierEncrypt(Encrypt):
  141. def __init__(self):
  142. super(PaillierEncrypt, self).__init__()
  143. def generate_key(self, n_length=1024):
  144. self.public_key, self.privacy_key = PaillierKeypair.generate_keypair(
  145. n_length=n_length
  146. )
  147. def get_key_pair(self):
  148. return self.public_key, self.privacy_key
  149. def set_public_key(self, public_key):
  150. self.public_key = public_key
  151. def get_public_key(self):
  152. return self.public_key
  153. def set_privacy_key(self, privacy_key):
  154. self.privacy_key = privacy_key
  155. def get_privacy_key(self):
  156. return self.privacy_key
  157. def encrypt(self, value):
  158. if self.public_key is not None:
  159. return self.public_key.encrypt(value)
  160. else:
  161. return None
  162. def decrypt(self, value):
  163. if self.privacy_key is not None:
  164. return self.privacy_key.decrypt(value)
  165. else:
  166. return None
  167. def raw_encrypt(self, plaintext, exponent=0):
  168. cipher_int = self.public_key.raw_encrypt(plaintext)
  169. paillier_num = PaillierEncryptedNumber(public_key=self.public_key, ciphertext=cipher_int, exponent=exponent)
  170. return paillier_num
  171. def raw_decrypt(self, ciphertext):
  172. return self.privacy_key.raw_decrypt(ciphertext.ciphertext())
  173. def recursive_raw_encrypt(self, X, exponent=0):
  174. raw_en_func = functools.partial(self.raw_encrypt, exponent=exponent)
  175. return self._recursive_func(X, raw_en_func)
  176. class IpclPaillierEncrypt(Encrypt):
  177. """
  178. A class to perform Paillier encryption with Intel Paillier Cryptosystem Library (IPCL)
  179. """
  180. def __init__(self):
  181. super(IpclPaillierEncrypt, self).__init__()
  182. def generate_key(self, n_length=1024):
  183. self.public_key, self.privacy_key = IpclPaillierKeypair.generate_keypair(
  184. n_length=n_length
  185. )
  186. def get_key_pair(self):
  187. return self.public_key, self.privacy_key
  188. def set_public_key(self, public_key):
  189. self.public_key = public_key
  190. def get_public_key(self):
  191. return self.public_key
  192. def set_privacy_key(self, privacy_key):
  193. self.privacy_key = privacy_key
  194. def get_privacy_key(self):
  195. return self.privacy_key
  196. def encrypt(self, value):
  197. if self.public_key is not None:
  198. return self.public_key.encrypt(value)
  199. else:
  200. return None
  201. def decrypt(self, value):
  202. if self.privacy_key is not None:
  203. return self.privacy_key.decrypt(value)
  204. else:
  205. return None
  206. def raw_encrypt(self, plaintext, exponent=0):
  207. """
  208. Encrypt without applying obfuscator.
  209. Returns:
  210. (PaillierEncryptedNumber from `ipcl_python`): one ciphertext
  211. """
  212. return self.public_key.raw_encrypt(plaintext)
  213. def raw_decrypt(self, ciphertext):
  214. """
  215. Decrypt without constructing `FixedPointNumber`.
  216. Returns:
  217. (int or list): raw value(s)
  218. """
  219. return self.privacy_key.raw_decrypt(ciphertext)
  220. def encrypt_list(self, values):
  221. """Encrypt a list of raw values into one ciphertext.
  222. Returns:
  223. (PaillierEncryptedNumber from `ipcl_python`): all in one single ciphertext
  224. """
  225. return self.encrypt(values)
  226. def decrypt_list(self, values):
  227. """
  228. Decrypt input values.
  229. If the type is list or 1-d numpy array, use `decrypt_list` of the parent class.
  230. Ohterwise, the type will be a 0-d numpy array, which contains one single ciphertext of multiple raw values.
  231. Use `item(0)` to fetch the ciphertext and then decrypt.
  232. Returns:
  233. (list): a list of raw values
  234. """
  235. if np.ndim(values) >= 1:
  236. return super().decrypt_list(values)
  237. return self.decrypt(values.item(0))
  238. def recursive_raw_encrypt(self, X, exponent=0):
  239. raw_en_func = functools.partial(self.raw_encrypt, exponent=exponent)
  240. return self._recursive_func(X, raw_en_func)
  241. class PadsCipher(Encrypt):
  242. def __init__(self):
  243. super().__init__()
  244. self._uuid = None
  245. self._rands = None
  246. self._amplify_factor = 1
  247. def set_self_uuid(self, uuid):
  248. self._uuid = uuid
  249. def set_amplify_factor(self, factor):
  250. self._amplify_factor = factor
  251. def set_exchanged_keys(self, keys):
  252. self._seeds = {
  253. uid: v & 0xFFFFFFFF for uid, v in keys.items() if uid != self._uuid
  254. }
  255. self._rands = {
  256. uid: RandomPads(v & 0xFFFFFFFF)
  257. for uid, v in keys.items()
  258. if uid != self._uuid
  259. }
  260. def encrypt(self, value):
  261. if isinstance(value, np.ndarray):
  262. ret = value
  263. for uid, rand in self._rands.items():
  264. if uid > self._uuid:
  265. ret = rand.add_rand_pads(ret, 1.0 * self._amplify_factor)
  266. else:
  267. ret = rand.add_rand_pads(ret, -1.0 * self._amplify_factor)
  268. return ret
  269. if _TORCH_VALID and isinstance(value, torch.Tensor):
  270. ret = value.numpy()
  271. for uid, rand in self._rands.items():
  272. if uid > self._uuid:
  273. ret = rand.add_rand_pads(ret, 1.0 * self._amplify_factor)
  274. else:
  275. ret = rand.add_rand_pads(ret, -1.0 * self._amplify_factor)
  276. return torch.Tensor(ret)
  277. ret = value
  278. for uid, rand in self._rands.items():
  279. if uid > self._uuid:
  280. ret += rand.rand(1)[0] * self._amplify_factor
  281. else:
  282. ret -= rand.rand(1)[0] * self._amplify_factor
  283. return ret
  284. def encrypt_table(self, table):
  285. def _pad(key, value, seeds, amplify_factor):
  286. has_key = int(hashlib.md5(f"{key}".encode("ascii")).hexdigest(), 16)
  287. # LOGGER.debug(f"hash_key: {has_key}")
  288. cur_seeds = {uid: has_key + seed for uid, seed in seeds.items()}
  289. # LOGGER.debug(f"cur_seeds: {cur_seeds}")
  290. rands = {uid: RandomPads(v & 0xFFFFFFFF) for uid, v in cur_seeds.items()}
  291. if isinstance(value, np.ndarray):
  292. ret = value
  293. for uid, rand in rands.items():
  294. if uid > self._uuid:
  295. ret = rand.add_rand_pads(ret, 1.0 * amplify_factor)
  296. else:
  297. ret = rand.add_rand_pads(ret, -1.0 * amplify_factor)
  298. return key, ret
  299. elif isinstance(value, Instance):
  300. ret = value.features
  301. for uid, rand in rands.items():
  302. if uid > self._uuid:
  303. ret = rand.add_rand_pads(ret, 1.0 * amplify_factor)
  304. else:
  305. ret = rand.add_rand_pads(ret, -1.0 * amplify_factor)
  306. value.features = ret
  307. return key, value
  308. else:
  309. ret = value
  310. for uid, rand in rands.items():
  311. if uid > self._uuid:
  312. ret += rand.rand(1)[0] * self._amplify_factor
  313. else:
  314. ret -= rand.rand(1)[0] * self._amplify_factor
  315. return key, ret
  316. f = functools.partial(
  317. _pad, seeds=self._seeds, amplify_factor=self._amplify_factor
  318. )
  319. return table.map(f)
  320. def decrypt(self, value):
  321. return value