123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from pipeline.param.base_param import BaseParam
- from pipeline.param.intersect_param import DHParam
- from pipeline.param import consts
class SecureInformationRetrievalParam(BaseParam):
    """
    Parameters for the secure information retrieval component.

    Parameters
    ----------
    security_level: float, default 0.5
        security level, should set value in [0, 1]
        if security_level equals 0.0 means raw data retrieval

    oblivious_transfer_protocol: {"OT_Hauck"}
        OT type, only supports OT_Hauck

    commutative_encryption : {"CommutativeEncryptionPohligHellman"}
        the commutative encryption scheme used

    non_committing_encryption : {"aes"}
        the non-committing encryption scheme used

    dh_params: DHParam, default None
        params for Pohlig-Hellman Encryption;
        when None, a fresh ``DHParam()`` is created for this instance

    key_size: int, value >= 1024
        the key length of the commutative cipher;
        note that this param will be deprecated in future,
        please specify key_length in dh_params (DHParam) instead.

    raw_retrieval: bool
        perform raw retrieval if raw_retrieval

    target_cols: str or list of str
        target cols to retrieve;
        any values not retrieved will be marked as "unretrieved",
        if target_cols is None, label will be retrieved, same behavior as in previous version
        default None (stored internally as an empty list)
    """

    def __init__(self, security_level=0.5,
                 oblivious_transfer_protocol=consts.OT_HAUCK,
                 commutative_encryption=consts.CE_PH,
                 non_committing_encryption=consts.AES,
                 key_size=consts.DEFAULT_KEY_LENGTH,
                 dh_params=None,
                 raw_retrieval=False,
                 target_cols=None):
        super(SecureInformationRetrievalParam, self).__init__()
        self.security_level = security_level
        self.oblivious_transfer_protocol = oblivious_transfer_protocol
        self.commutative_encryption = commutative_encryption
        self.non_committing_encryption = non_committing_encryption
        # Build the default per instance: a DHParam() in the signature would be
        # created once at definition time and shared by every instance.
        self.dh_params = DHParam() if dh_params is None else dh_params
        self.key_size = key_size
        self.raw_retrieval = raw_retrieval
        self.target_cols = [] if target_cols is None else target_cols

    def check(self):
        """
        Validate the stored parameters and normalise some of them in place.

        The three protocol/scheme names are replaced by the value returned
        from ``check_and_change_lower``; a non-list ``target_cols`` is wrapped
        into a single-element list.

        Raises
        ------
        ValueError
            if ``key_size`` is set and smaller than 1024.
            The ``check_*`` helpers of the base class and ``dh_params.check()``
            may raise their own errors for invalid values.
        """
        descr = "secure information retrieval param's "
        self.check_decimal_float(self.security_level, descr + "security_level")
        self.oblivious_transfer_protocol = self.check_and_change_lower(
            self.oblivious_transfer_protocol,
            [consts.OT_HAUCK.lower()],
            descr + "oblivious_transfer_protocol")
        self.commutative_encryption = self.check_and_change_lower(
            self.commutative_encryption,
            [consts.CE_PH.lower()],
            descr + "commutative_encryption")
        self.non_committing_encryption = self.check_and_change_lower(
            self.non_committing_encryption,
            [consts.AES.lower()],
            descr + "non_committing_encryption")
        self.dh_params.check()
        # A falsy key_size (None / 0) skips validation entirely.
        if self.key_size:
            self.check_positive_integer(self.key_size, descr + "key_size")
            if self.key_size < 1024:
                raise ValueError("key size must be >= 1024")
        # Name the parameter in the description so a failure is attributable.
        self.check_boolean(self.raw_retrieval, descr + "raw_retrieval")
        # Accept a single column name by wrapping it into a list.
        if not isinstance(self.target_cols, list):
            self.target_cols = [self.target_cols]
        for col in self.target_cols:
            self.check_string(col, descr + "target_cols")
|