fate_paillier_test.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import numpy as np
  17. import unittest
  18. from federatedml.secureprotol.fate_paillier import PaillierKeypair
  19. from federatedml.secureprotol.fate_paillier import PaillierPublicKey
  20. from federatedml.secureprotol.fate_paillier import PaillierPrivateKey
  21. from federatedml.secureprotol.fate_paillier import PaillierEncryptedNumber
  22. class TestPaillierEncryptedNumber(unittest.TestCase):
  23. def setUp(self):
  24. self.public_key, self.private_key = PaillierKeypair.generate_keypair()
  25. def tearDown(self):
  26. unittest.TestCase.tearDown(self)
  27. def test_add(self):
  28. x_li = np.ones(100) * np.random.randint(100)
  29. y_li = np.ones(100) * np.random.randint(1000)
  30. z_li = np.ones(100) * np.random.rand()
  31. t_li = range(100)
  32. for i in range(x_li.shape[0]):
  33. x = x_li[i]
  34. y = y_li[i]
  35. z = z_li[i]
  36. t = t_li[i]
  37. en_x = self.public_key.encrypt(x)
  38. en_y = self.public_key.encrypt(y)
  39. en_z = self.public_key.encrypt(z)
  40. en_t = self.public_key.encrypt(t)
  41. en_res = en_x + en_y + en_z + en_t
  42. res = x + y + z + t
  43. de_en_res = self.private_key.decrypt(en_res)
  44. self.assertAlmostEqual(de_en_res, res)
  45. def test_mul(self):
  46. x_li = np.ones(100) * np.random.randint(100)
  47. y_li = np.ones(100) * np.random.randint(1000) * -1
  48. z_li = np.ones(100) * np.random.rand()
  49. t_li = range(100)
  50. for i in range(x_li.shape[0]):
  51. x = x_li[i]
  52. y = y_li[i]
  53. z = z_li[i]
  54. t = t_li[i]
  55. en_x = self.public_key.encrypt(x)
  56. en_res = (en_x * y + z) * t
  57. res = (x * y + z) * t
  58. de_en_res = self.private_key.decrypt(en_res)
  59. self.assertAlmostEqual(de_en_res, res)
  60. x = 9
  61. en_x = self.public_key.encrypt(x)
  62. for i in range(100):
  63. en_x = en_x + 5000 - 0.2
  64. x = x + 5000 - 0.2
  65. de_en_x = self.private_key.decrypt(en_x)
  66. self.assertAlmostEqual(de_en_x, x)
  67. if __name__ == '__main__':
  68. unittest.main()