sir_param.py 4.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from federatedml.param.base_param import BaseParam, deprecated_param
  19. from federatedml.param.base_param import BaseParam
  20. from federatedml.param.intersect_param import DHParam
  21. from federatedml.util import consts, LOGGER
  22. @deprecated_param("key_size", "raw_retrieval")
  23. class SecureInformationRetrievalParam(BaseParam):
  24. """
  25. Parameters
  26. ----------
  27. security_level: float, default 0.5
  28. security level, should set value in [0, 1]
  29. if security_level equals 0.0 means raw data retrieval
  30. oblivious_transfer_protocol: {"OT_Hauck"}
  31. OT type, only supports OT_Hauck
  32. commutative_encryption : {"CommutativeEncryptionPohligHellman"}
  33. the commutative encryption scheme used
  34. non_committing_encryption : {"aes"}
  35. the non-committing encryption scheme used
  36. dh_params
  37. params for Pohlig-Hellman Encryption
  38. key_size: int, value >= 1024
  39. the key length of the commutative cipher;
  40. note that this param will be deprecated in future, please specify key_length in PHParam instead.
  41. raw_retrieval: bool
  42. perform raw retrieval if raw_retrieval
  43. target_cols: str or list of str
  44. target cols to retrieve;
  45. any values not retrieved will be marked as "unretrieved",
  46. if target_cols is None, label will be retrieved, same behavior as in previous version
  47. default None
  48. """
  49. def __init__(self, security_level=0.5,
  50. oblivious_transfer_protocol=consts.OT_HAUCK,
  51. commutative_encryption=consts.CE_PH,
  52. non_committing_encryption=consts.AES,
  53. key_size=consts.DEFAULT_KEY_LENGTH,
  54. dh_params=DHParam(),
  55. raw_retrieval=False,
  56. target_cols=None):
  57. super(SecureInformationRetrievalParam, self).__init__()
  58. self.security_level = security_level
  59. self.oblivious_transfer_protocol = oblivious_transfer_protocol
  60. self.commutative_encryption = commutative_encryption
  61. self.non_committing_encryption = non_committing_encryption
  62. self.dh_params = dh_params
  63. self.key_size = key_size
  64. self.raw_retrieval = raw_retrieval
  65. self.target_cols = target_cols
  66. def check(self):
  67. descr = "secure information retrieval param's "
  68. self.check_decimal_float(self.security_level, descr + "security_level")
  69. self.oblivious_transfer_protocol = self.check_and_change_lower(self.oblivious_transfer_protocol,
  70. [consts.OT_HAUCK.lower()],
  71. descr + "oblivious_transfer_protocol")
  72. self.commutative_encryption = self.check_and_change_lower(self.commutative_encryption,
  73. [consts.CE_PH.lower()],
  74. descr + "commutative_encryption")
  75. self.non_committing_encryption = self.check_and_change_lower(self.non_committing_encryption,
  76. [consts.AES.lower()],
  77. descr + "non_committing_encryption")
  78. if self._warn_to_deprecate_param("key_size", descr, "dh_param's key_length"):
  79. self.dh_params.key_length = self.key_size
  80. self.dh_params.check()
  81. if self._warn_to_deprecate_param("raw_retrieval", descr, "dh_param's security_level = 0"):
  82. self.check_boolean(self.raw_retrieval, descr)
  83. self.target_cols = [] if self.target_cols is None else self.target_cols
  84. if not isinstance(self.target_cols, list):
  85. self.target_cols = [self.target_cols]
  86. for col in self.target_cols:
  87. self.check_string(col, descr + "target_cols")
  88. if len(self.target_cols) == 0:
  89. LOGGER.warning(f"Both 'target_cols' and 'target_indexes' are empty. Label will be retrieved.")