encrypt_mode_test.py 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. import numpy as np
  18. import unittest
  19. class TestEncryptModeCalculator(unittest.TestCase):
  20. def setUp(self):
  21. from fate_arch.session import computing_session as session
  22. session.init("test_encrypt_mode_calculator")
  23. self.list_data = []
  24. self.tuple_data = []
  25. self.numpy_data = []
  26. for i in range(30):
  27. list_value = [100 * i + j for j in range(20)]
  28. tuple_value = tuple(list_value)
  29. numpy_value = np.array(list_value, dtype="int")
  30. self.list_data.append(list_value)
  31. self.tuple_data.append(tuple_value)
  32. self.numpy_data.append(numpy_value)
  33. self.data_list = session.parallelize(self.list_data, include_key=False, partition=10)
  34. self.data_tuple = session.parallelize(self.tuple_data, include_key=False, partition=10)
  35. self.data_numpy = session.parallelize(self.numpy_data, include_key=False, partition=10)
  36. def test_data_type(self, mode="strict", re_encrypted_rate=0.2):
  37. from federatedml.secureprotol import PaillierEncrypt
  38. from federatedml.secureprotol.encrypt_mode import EncryptModeCalculator
  39. encrypter = PaillierEncrypt()
  40. encrypter.generate_key(1024)
  41. encrypted_calculator = EncryptModeCalculator(encrypter, mode, re_encrypted_rate)
  42. data_list = dict(encrypted_calculator.encrypt(self.data_list).collect())
  43. data_tuple = dict(encrypted_calculator.encrypt(self.data_tuple).collect())
  44. data_numpy = dict(encrypted_calculator.encrypt(self.data_numpy).collect())
  45. for key, value in data_list.items():
  46. self.assertTrue(isinstance(value, list))
  47. self.assertTrue(len(value) == len(self.list_data[key]))
  48. for key, value in data_tuple.items():
  49. self.assertTrue(isinstance(value, tuple))
  50. self.assertTrue(len(value) == len(self.tuple_data[key]))
  51. for key, value in data_numpy.items():
  52. self.assertTrue(type(value).__name__ == "ndarray")
  53. self.assertTrue(value.shape[0] == self.numpy_data[key].shape[0])
  54. def test_data_type_with_diff_mode(self):
  55. mode_list = ["strict", "fast", "confusion_opt", "balance", "confusion_opt_balance"]
  56. for mode in mode_list:
  57. self.test_data_type(mode=mode)
  58. def test_diff_mode(self, round=10, mode="strict", re_encrypted_rate=0.2):
  59. from federatedml.secureprotol.encrypt_mode import EncryptModeCalculator
  60. from federatedml.secureprotol import PaillierEncrypt
  61. encrypter = PaillierEncrypt()
  62. encrypter.generate_key(1024)
  63. encrypted_calculator = EncryptModeCalculator(encrypter, mode, re_encrypted_rate)
  64. for i in range(round):
  65. data_i = self.data_numpy.mapValues(lambda v: v + i)
  66. data_i = encrypted_calculator.encrypt(data_i)
  67. decrypt_data_i = dict(data_i.mapValues(lambda arr: np.array(
  68. [encrypter.decrypt(val) for val in arr])).collect())
  69. for j in range(30):
  70. self.assertTrue(np.fabs(self.numpy_data[j] - decrypt_data_i[j] + i).all() < 1e-5)
  71. if __name__ == '__main__':
  72. unittest.main()