sir_param.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from pipeline.param.base_param import BaseParam
  19. from pipeline.param.intersect_param import DHParam
  20. from pipeline.param import consts
  21. class SecureInformationRetrievalParam(BaseParam):
  22. """
  23. Parameters
  24. ----------
  25. security_level: float, default 0.5
  26. security level, should set value in [0, 1]
  27. if security_level equals 0.0 means raw data retrieval
  28. oblivious_transfer_protocol: {"OT_Hauck"}
  29. OT type, only supports OT_Hauck
  30. commutative_encryption : {"CommutativeEncryptionPohligHellman"}
  31. the commutative encryption scheme used
  32. non_committing_encryption : {"aes"}
  33. the non-committing encryption scheme used
  34. dh_params
  35. params for Pohlig-Hellman Encryption
  36. key_size: int, value >= 1024
  37. the key length of the commutative cipher;
  38. note that this param will be deprecated in future, please specify key_length in PHParam instead.
  39. raw_retrieval: bool
  40. perform raw retrieval if raw_retrieval
  41. target_cols: str or list of str
  42. target cols to retrieve;
  43. any values not retrieved will be marked as "unretrieved",
  44. if target_cols is None, label will be retrieved, same behavior as in previous version
  45. default None
  46. """
  47. def __init__(self, security_level=0.5,
  48. oblivious_transfer_protocol=consts.OT_HAUCK,
  49. commutative_encryption=consts.CE_PH,
  50. non_committing_encryption=consts.AES,
  51. key_size=consts.DEFAULT_KEY_LENGTH,
  52. dh_params=DHParam(),
  53. raw_retrieval=False,
  54. target_cols=None):
  55. super(SecureInformationRetrievalParam, self).__init__()
  56. self.security_level = security_level
  57. self.oblivious_transfer_protocol = oblivious_transfer_protocol
  58. self.commutative_encryption = commutative_encryption
  59. self.non_committing_encryption = non_committing_encryption
  60. self.dh_params = dh_params
  61. self.key_size = key_size
  62. self.raw_retrieval = raw_retrieval
  63. self.target_cols = [] if target_cols is None else target_cols
  64. def check(self):
  65. descr = "secure information retrieval param's "
  66. self.check_decimal_float(self.security_level, descr + "security_level")
  67. self.oblivious_transfer_protocol = self.check_and_change_lower(self.oblivious_transfer_protocol,
  68. [consts.OT_HAUCK.lower()],
  69. descr + "oblivious_transfer_protocol")
  70. self.commutative_encryption = self.check_and_change_lower(self.commutative_encryption,
  71. [consts.CE_PH.lower()],
  72. descr + "commutative_encryption")
  73. self.non_committing_encryption = self.check_and_change_lower(self.non_committing_encryption,
  74. [consts.AES.lower()],
  75. descr + "non_committing_encryption")
  76. self.dh_params.check()
  77. if self.key_size:
  78. self.check_positive_integer(self.key_size, descr + "key_size")
  79. if self.key_size < 1024:
  80. raise ValueError(f"key size must be >= 1024")
  81. self.check_boolean(self.raw_retrieval, descr)
  82. if not isinstance(self.target_cols, list):
  83. self.target_cols = [self.target_cols]
  84. for col in self.target_cols:
  85. self.check_string(col, descr + "target_cols")