encrypt_mode.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import functools
  17. from federatedml.secureprotol import PaillierEncrypt
  18. from federatedml.util import LOGGER
  class EncryptModeCalculator(object):
      """
      Encrypt-mode wrapper around an encrypter object.

      ``mode`` accepts 'strict', 'fast', 'balance', 'confusion_opt' and
      'confusion_opt_balance' for compatibility, but only 'strict' is
      implemented. Any other value is reset to 'strict' and a warning is
      logged. In strict mode, every encrypt call encrypts its input from
      scratch.

      Parameters
      ----------
      encrypter: object, encrypter used to encrypt and decrypt numbers
          (e.g. a fate-paillier object). It must provide ``recursive_encrypt``,
          ``recursive_raw_encrypt``, ``recursive_decrypt`` and
          ``distribute_decrypt``.
      mode: str, requested encrypt mode. Only 'strict' is honoured; any other
          value is replaced by 'strict'.
      re_encrypted_rate: number, stored on the instance; no method of this
          class reads it.
      """

      def __init__(self, encrypter=None, mode="strict", re_encrypted_rate=1):
          self.encrypter = encrypter
          self.mode = mode
          self.re_encrypted_rate = re_encrypted_rate
          # No method of this class reads these attributes. They are presumably
          # kept for external code that still accesses them (TODO confirm).
          self.prev_data = None
          self.prev_encrypted_data = None
          self.enc_zeros = None
          self.align_to_input_data = True

          if self.mode != "strict":
              self.mode = "strict"
              LOGGER.warning("encrypted_mode_calculator will be remove in later version, "
                             "but in current version user can still use it, but it only supports strict mode, "
                             "other mode will be reset to strict for compatibility")

      @staticmethod
      def add_enc_zero(obj, enc_zero):
          """No-op that always returns None. Presumably retained so existing callers keep working."""
          pass

      def encrypt_data(self, input_data, enc_func):
          """
          Apply ``enc_func`` to every value of ``input_data``.

          Parameters
          ----------
          input_data: Table (anything exposing ``mapValues``)
          enc_func: callable applied to each value

          Returns
          -------
          The result of ``input_data.mapValues(enc_func)``.
          """
          return input_data.mapValues(enc_func)

      def get_enc_func(self, encrypter, raw_enc=False, exponent=0):
          """
          Return the per-value encryption function of ``encrypter``.

          Parameters
          ----------
          encrypter: object, encrypter whose method is returned.
          raw_enc: bool, if False return ``encrypter.recursive_encrypt``;
              if True return ``encrypter.recursive_raw_encrypt``.
          exponent: int, bound as the ``exponent`` keyword of the raw encrypt
              function when ``encrypter`` is a PaillierEncrypt. It is ignored
              for other encrypters and on the non-raw path.

          Returns
          -------
          callable applied to each value to encrypt it.
          """
          if not raw_enc:
              return encrypter.recursive_encrypt

          # Consult the ``encrypter`` argument (not ``self.encrypter``) so the
          # raw and non-raw paths agree on which encrypter they use.
          if isinstance(encrypter, PaillierEncrypt):
              return functools.partial(encrypter.recursive_raw_encrypt, exponent=exponent)
          return encrypter.recursive_raw_encrypt

      def encrypt(self, input_data):
          """
          Encrypt every value of ``input_data`` with ``self.encrypter.recursive_encrypt``.

          Parameters
          ----------
          input_data: Table

          Returns
          -------
          new_data: Table, encrypted result of input_data
          """
          encrypt_func = self.get_enc_func(self.encrypter, raw_enc=False)
          return self.encrypt_data(input_data, encrypt_func)

      def raw_encrypt(self, input_data, exponent=0):
          """
          Encrypt every value of ``input_data`` with ``self.encrypter.recursive_raw_encrypt``.

          ``exponent`` is forwarded only when the encrypter is a PaillierEncrypt.
          """
          raw_en_func = self.get_enc_func(self.encrypter, raw_enc=True, exponent=exponent)
          return self.encrypt_data(input_data, raw_en_func)

      def init_enc_zero(self, input_data, raw_en=False, exponent=0):
          """No-op that always returns None. Presumably retained so existing callers keep working."""
          pass

      def recursive_encrypt(self, input_data):
          """Encrypt ``input_data`` directly with ``self.encrypter.recursive_encrypt``."""
          return self.encrypter.recursive_encrypt(input_data)

      def distribute_encrypt(self, input_data):
          """Alias of :meth:`encrypt`: encrypt every value of a Table."""
          return self.encrypt(input_data)

      def distribute_decrypt(self, input_data):
          """Delegate to ``self.encrypter.distribute_decrypt``."""
          return self.encrypter.distribute_decrypt(input_data)

      def recursive_decrypt(self, input_data):
          """Delegate to ``self.encrypter.recursive_decrypt``."""
          return self.encrypter.recursive_decrypt(input_data)