intersect_param.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import copy
  19. from pipeline.param.base_param import BaseParam
  20. from pipeline.param import consts
# Default for the ``random_bit`` parameter of RSAParam and IntersectParam
# (documented there as the size of the blinding factor in the rsa algorithm).
DEFAULT_RANDOM_BIT = 128
  22. class EncodeParam(BaseParam):
  23. """
  24. Define the hash method for raw intersect method
  25. Parameters
  26. ----------
  27. salt: str
  28. the src data string will be str = str + salt, default by empty string
  29. encode_method: {"none", "md5", "sha1", "sha224", "sha256", "sha384", "sha512", "sm3"}
  30. the hash method of src data string, support md5, sha1, sha224, sha256, sha384, sha512, sm3, default by None
  31. base64: bool
  32. if True, the result of hash will be changed to base64, default by False
  33. """
  34. def __init__(self, salt='', encode_method='none', base64=False):
  35. super().__init__()
  36. self.salt = salt
  37. self.encode_method = encode_method
  38. self.base64 = base64
  39. def check(self):
  40. if type(self.salt).__name__ != "str":
  41. raise ValueError(
  42. "encode param's salt {} not supported, should be str type".format(
  43. self.salt))
  44. descr = "encode param's "
  45. self.encode_method = self.check_and_change_lower(self.encode_method,
  46. ["none", consts.MD5, consts.SHA1, consts.SHA224,
  47. consts.SHA256, consts.SHA384, consts.SHA512,
  48. consts.SM3],
  49. descr)
  50. if type(self.base64).__name__ != "bool":
  51. raise ValueError(
  52. "hash param's base64 {} not supported, should be bool type".format(self.base64))
  53. return True
  54. class RAWParam(BaseParam):
  55. """
  56. Specify parameters for raw intersect method
  57. Parameters
  58. ----------
  59. use_hash: bool
  60. whether to hash ids for raw intersect
  61. salt: str
  62. the src data string will be str = str + salt, default by empty string
  63. hash_method: str
  64. the hash method of src data string, support md5, sha1, sha224, sha256, sha384, sha512, sm3, default by None
  65. base64: bool
  66. if True, the result of hash will be changed to base64, default by False
  67. join_role: {"guest", "host"}
  68. role who joins ids, supports "guest" and "host" only and effective only for raw.
  69. If it is "guest", the host will send its ids to guest and find the intersection of
  70. ids in guest; if it is "host", the guest will send its ids to host. Default by "guest";
  71. """
  72. def __init__(self, use_hash=False, salt='', hash_method='none', base64=False, join_role=consts.GUEST):
  73. super().__init__()
  74. self.use_hash = use_hash
  75. self.salt = salt
  76. self.hash_method = hash_method
  77. self.base64 = base64
  78. self.join_role = join_role
  79. def check(self):
  80. descr = "raw param's "
  81. self.check_boolean(self.use_hash, f"{descr}use_hash")
  82. self.check_string(self.salt, f"{descr}salt")
  83. self.hash_method = self.check_and_change_lower(self.hash_method,
  84. ["none", consts.MD5, consts.SHA1, consts.SHA224,
  85. consts.SHA256, consts.SHA384, consts.SHA512,
  86. consts.SM3],
  87. f"{descr}hash_method")
  88. self.check_boolean(self.base64, f"{descr}base_64")
  89. self.join_role = self.check_and_change_lower(self.join_role, [consts.GUEST, consts.HOST], f"{descr}join_role")
  90. return True
  91. class RSAParam(BaseParam):
  92. """
  93. Specify parameters for RSA intersect method
  94. Parameters
  95. ----------
  96. salt: str
  97. the src data string will be str = str + salt, default ''
  98. hash_method: str
  99. the hash method of src data string, support sha256, sha384, sha512, sm3, default sha256
  100. final_hash_method: str
  101. the hash method of result data string, support md5, sha1, sha224, sha256, sha384, sha512, sm3, default sha256
  102. split_calculation: bool
  103. if True, Host & Guest split operations for faster performance, recommended on large data set
  104. random_base_fraction: positive float
  105. if not None, generate (fraction * public key id count) of r for encryption and reuse generated r;
  106. note that value greater than 0.99 will be taken as 1, and value less than 0.01 will be rounded up to 0.01
  107. key_length: int
  108. value >= 1024, bit count of rsa key, default 1024
  109. random_bit: positive int
  110. it will define the size of blinding factor in rsa algorithm, default 128
  111. """
  112. def __init__(self, salt='', hash_method='sha256', final_hash_method='sha256',
  113. split_calculation=False, random_base_fraction=None, key_length=consts.DEFAULT_KEY_LENGTH,
  114. random_bit=DEFAULT_RANDOM_BIT):
  115. super().__init__()
  116. self.salt = salt
  117. self.hash_method = hash_method
  118. self.final_hash_method = final_hash_method
  119. self.split_calculation = split_calculation
  120. self.random_base_fraction = random_base_fraction
  121. self.key_length = key_length
  122. self.random_bit = random_bit
  123. def check(self):
  124. descr = "rsa param's "
  125. self.check_string(self.salt, f"{descr}salt")
  126. self.hash_method = self.check_and_change_lower(self.hash_method,
  127. [consts.SHA256, consts.SHA384, consts.SHA512, consts.SM3],
  128. f"{descr}hash_method")
  129. self.final_hash_method = self.check_and_change_lower(self.final_hash_method,
  130. [consts.MD5, consts.SHA1, consts.SHA224,
  131. consts.SHA256, consts.SHA384, consts.SHA512,
  132. consts.SM3],
  133. f"{descr}final_hash_method")
  134. self.check_boolean(self.split_calculation, f"{descr}split_calculation")
  135. if self.random_base_fraction:
  136. self.check_positive_number(self.random_base_fraction, descr)
  137. self.check_decimal_float(self.random_base_fraction, f"{descr}random_base_fraction")
  138. self.check_positive_integer(self.key_length, f"{descr}key_length")
  139. if self.key_length < 1024:
  140. raise ValueError(f"key length must be >= 1024")
  141. self.check_positive_integer(self.random_bit, f"{descr}random_bit")
  142. return True
  143. class DHParam(BaseParam):
  144. """
  145. Define the hash method for DH intersect method
  146. Parameters
  147. ----------
  148. salt: str
  149. the src data string will be str = str + salt, default ''
  150. hash_method: str
  151. the hash method of src data string, support none, md5, sha1, sha 224, sha256, sha384, sha512, sm3, default sha256
  152. key_length: int, value >= 1024
  153. the key length of the commutative cipher p, default 1024
  154. """
  155. def __init__(self, salt='', hash_method='sha256', key_length=consts.DEFAULT_KEY_LENGTH):
  156. super().__init__()
  157. self.salt = salt
  158. self.hash_method = hash_method
  159. self.key_length = key_length
  160. def check(self):
  161. descr = "dh param's "
  162. self.check_string(self.salt, f"{descr}salt")
  163. self.hash_method = self.check_and_change_lower(self.hash_method,
  164. ["none", consts.MD5, consts.SHA1, consts.SHA224,
  165. consts.SHA256, consts.SHA384, consts.SHA512,
  166. consts.SM3],
  167. f"{descr}hash_method")
  168. self.check_positive_integer(self.key_length, f"{descr}key_length")
  169. if self.key_length < 1024:
  170. raise ValueError(f"key length must be >= 1024")
  171. return True
  172. class ECDHParam(BaseParam):
  173. """
  174. Define the hash method for ECDH intersect method
  175. Parameters
  176. ----------
  177. salt: str
  178. the src id will be str = str + salt, default ''
  179. hash_method: str
  180. the hash method of src id, support sha256, sha384, sha512, sm3, default sha256
  181. curve: str
  182. the name of curve, currently only support 'curve25519', which offers 128 bits of security
  183. """
  184. def __init__(self, salt='', hash_method='sha256', curve=consts.CURVE25519):
  185. super().__init__()
  186. self.salt = salt
  187. self.hash_method = hash_method
  188. self.curve = curve
  189. def check(self):
  190. descr = "ecdh param's "
  191. self.check_string(self.salt, f"{descr}salt")
  192. self.hash_method = self.check_and_change_lower(self.hash_method,
  193. [consts.SHA256, consts.SHA384, consts.SHA512,
  194. consts.SM3],
  195. f"{descr}hash_method")
  196. self.curve = self.check_and_change_lower(self.curve, [consts.CURVE25519], f"{descr}curve")
  197. return True
  198. class IntersectCache(BaseParam):
  199. def __init__(self, use_cache=False, id_type=consts.PHONE, encrypt_type=consts.SHA256):
  200. """
  201. Parameters
  202. ----------
  203. use_cache: whether to use cached ids; with ver1.7 and above, this param is ignored
  204. id_type: with ver1.7 and above, this param is ignored
  205. encrypt_type: with ver1.7 and above, this param is ignored
  206. """
  207. super().__init__()
  208. self.use_cache = use_cache
  209. self.id_type = id_type
  210. self.encrypt_type = encrypt_type
  211. def check(self):
  212. descr = "intersect_cache param's "
  213. # self.check_boolean(self.use_cache, f"{descr}use_cache")
  214. self.check_and_change_lower(self.id_type,
  215. [consts.PHONE, consts.IMEI],
  216. f"{descr}id_type")
  217. self.check_and_change_lower(self.encrypt_type,
  218. [consts.MD5, consts.SHA256],
  219. f"{descr}encrypt_type")
  220. class IntersectPreProcessParam(BaseParam):
  221. """
  222. Specify parameters for pre-processing and cardinality-only mode
  223. Parameters
  224. ----------
  225. false_positive_rate: float
  226. initial target false positive rate when creating Bloom Filter,
  227. must be <= 0.5, default 1e-3
  228. encrypt_method: str
  229. encrypt method for encrypting id when performing cardinality_only task,
  230. supports rsa only, default rsa;
  231. specify rsa parameter setting with RSAParam
  232. hash_method: str
  233. the hash method for inserting ids, support md5, sha1, sha 224, sha256, sha384, sha512, sm3,
  234. default sha256
  235. preprocess_method: str
  236. the hash method for encoding ids before insertion into filter, default sha256,
  237. only effective for preprocessing
  238. preprocess_salt: str
  239. salt to be appended to hash result by preprocess_method before insertion into filter,
  240. default '', only effective for preprocessing
  241. random_state: int
  242. seed for random salt generator when constructing hash functions,
  243. salt is appended to hash result by hash_method when performing insertion, default None
  244. filter_owner: str
  245. role that constructs filter, either guest or host, default guest,
  246. only effective for preprocessing
  247. """
  248. def __init__(self, false_positive_rate=1e-3, encrypt_method=consts.RSA, hash_method='sha256',
  249. preprocess_method='sha256', preprocess_salt='', random_state=None, filter_owner=consts.GUEST):
  250. super().__init__()
  251. self.false_positive_rate = false_positive_rate
  252. self.encrypt_method = encrypt_method
  253. self.hash_method = hash_method
  254. self.preprocess_method = preprocess_method
  255. self.preprocess_salt = preprocess_salt
  256. self.random_state = random_state
  257. self.filter_owner = filter_owner
  258. def check(self):
  259. descr = "intersect preprocess param's false_positive_rate "
  260. self.check_decimal_float(self.false_positive_rate, descr)
  261. self.check_positive_number(self.false_positive_rate, descr)
  262. if self.false_positive_rate > 0.5:
  263. raise ValueError(f"{descr} must be positive float no greater than 0.5")
  264. descr = "intersect preprocess param's encrypt_method "
  265. self.encrypt_method = self.check_and_change_lower(self.encrypt_method, [consts.RSA], descr)
  266. descr = "intersect preprocess param's random_state "
  267. if self.random_state:
  268. self.check_nonnegative_number(self.random_state, descr)
  269. descr = "intersect preprocess param's hash_method "
  270. self.hash_method = self.check_and_change_lower(self.hash_method,
  271. [consts.MD5, consts.SHA1, consts.SHA224,
  272. consts.SHA256, consts.SHA384, consts.SHA512,
  273. consts.SM3],
  274. descr)
  275. descr = "intersect preprocess param's preprocess_salt "
  276. self.check_string(self.preprocess_salt, descr)
  277. descr = "intersect preprocess param's preprocess_method "
  278. self.preprocess_method = self.check_and_change_lower(self.preprocess_method,
  279. [consts.MD5, consts.SHA1, consts.SHA224,
  280. consts.SHA256, consts.SHA384, consts.SHA512,
  281. consts.SM3],
  282. descr)
  283. descr = "intersect preprocess param's filter_owner "
  284. self.filter_owner = self.check_and_change_lower(self.filter_owner,
  285. [consts.GUEST, consts.HOST],
  286. descr)
  287. return True
  288. class IntersectParam(BaseParam):
  289. """
  290. Define the intersect method
  291. Parameters
  292. ----------
  293. intersect_method: str
  294. it supports 'rsa', 'raw', 'dh', default by 'rsa'
  295. random_bit: positive int
  296. it will define the size of blinding factor in rsa algorithm, default 128
  297. note that this param will be deprecated in future, please use random_bit in RSAParam instead
  298. sync_intersect_ids: bool
  299. In rsa, 'sync_intersect_ids' is True means guest or host will send intersect results to the others, and False will not.
  300. while in raw, 'sync_intersect_ids' is True means the role of "join_role" will send intersect results and the others will get them.
  301. Default by True.
  302. join_role: str
  303. role who joins ids, supports "guest" and "host" only and effective only for raw.
  304. If it is "guest", the host will send its ids to guest and find the intersection of
  305. ids in guest; if it is "host", the guest will send its ids to host. Default by "guest";
  306. note this param will be deprecated in future version, please use 'join_role' in raw_params instead
  307. only_output_key: bool
  308. if false, the results of intersection will include key and value which from input data; if true, it will just include key from input
  309. data and the value will be empty or filled by uniform string like "intersect_id"
  310. with_encode: bool
  311. if True, it will use hash method for intersect ids, effective for raw method only;
  312. note that this param will be deprecated in future version, please use 'use_hash' in raw_params;
  313. currently if this param is set to True,
  314. specification by 'encode_params' will be taken instead of 'raw_params'.
  315. encode_params: EncodeParam
  316. effective only when with_encode is True;
  317. this param will be deprecated in future version, use 'raw_params' in future implementation
  318. raw_params: RAWParam
  319. effective for raw method only
  320. rsa_params: RSAParam
  321. effective for rsa method only
  322. dh_params: DHParam
  323. effective for dh method only
  324. ecdh_params: ECDHParam
  325. effective for ecdh method only
  326. join_method: {'inner_join', 'left_join'}
  327. if 'left_join', participants will all include sample_id_generator's (imputed) ids in output,
  328. default 'inner_join'
  329. new_sample_id: bool
  330. whether to generate new id for sample_id_generator's ids,
  331. only effective when join_method is 'left_join' or when input data are instance with match id,
  332. default False
  333. sample_id_generator: str
  334. role whose ids are to be kept,
  335. effective only when join_method is 'left_join' or when input data are instance with match id,
  336. default 'guest'
  337. intersect_cache_param: IntersectCacheParam
  338. specification for cache generation,
  339. with ver1.7 and above, this param is ignored.
  340. run_cache: bool
  341. whether to store Host's encrypted ids, only valid when intersect method is 'rsa', 'dh', or 'ecdh', default False
  342. cardinality_only: bool
  343. whether to output intersection count(cardinality);
  344. if sync_cardinality is True, then sync cardinality count with host(s)
  345. cardinality_method: string
  346. specify which intersect method to use for coutning cardinality, default "ecdh";
  347. note that with "rsa", estimated cardinality will be produced;
  348. while "dh" method outputs exact cardinality, it only supports single-host task
  349. sync_cardinality: bool
  350. whether to sync cardinality with all participants, default False,
  351. only effective when cardinality_only set to True
  352. run_preprocess: bool
  353. whether to run preprocess process, default False
  354. intersect_preprocess_params: IntersectPreProcessParam
  355. used for preprocessing and cardinality_only mode
  356. repeated_id_process: bool
  357. if true, intersection will process the ids which can be repeatable;
  358. in ver 1.7 and above,repeated id process
  359. will be automatically applied to data with instance id, this param will be ignored
  360. repeated_id_owner: str
  361. which role has the repeated id; in ver 1.7 and above, this param is ignored
  362. allow_info_share: bool
  363. in ver 1.7 and above, this param is ignored
  364. info_owner: str
  365. in ver 1.7 and above, this param is ignored
  366. with_sample_id: bool
  367. data with sample id or not, default False; in ver 1.7 and above, this param is ignored
  368. """
  369. def __init__(self, intersect_method: str = consts.RSA, random_bit=DEFAULT_RANDOM_BIT, sync_intersect_ids=True,
  370. join_role=consts.GUEST, only_output_key: bool = False,
  371. with_encode=False, encode_params=EncodeParam(),
  372. raw_params=RAWParam(), rsa_params=RSAParam(), dh_params=DHParam(), ecdh_params=ECDHParam(),
  373. join_method=consts.INNER_JOIN, new_sample_id: bool = False, sample_id_generator=consts.GUEST,
  374. intersect_cache_param=IntersectCache(), run_cache: bool = False,
  375. cardinality_only: bool = False, sync_cardinality: bool = False, cardinality_method=consts.ECDH,
  376. run_preprocess: bool = False,
  377. intersect_preprocess_params=IntersectPreProcessParam(),
  378. repeated_id_process=False, repeated_id_owner=consts.GUEST,
  379. with_sample_id=False, allow_info_share: bool = False, info_owner=consts.GUEST):
  380. super().__init__()
  381. self.intersect_method = intersect_method
  382. self.random_bit = random_bit
  383. self.sync_intersect_ids = sync_intersect_ids
  384. self.join_role = join_role
  385. self.with_encode = with_encode
  386. self.encode_params = copy.deepcopy(encode_params)
  387. self.raw_params = copy.deepcopy(raw_params)
  388. self.rsa_params = copy.deepcopy(rsa_params)
  389. self.only_output_key = only_output_key
  390. self.sample_id_generator = sample_id_generator
  391. self.intersect_cache_param = copy.deepcopy(intersect_cache_param)
  392. self.run_cache = run_cache
  393. self.repeated_id_process = repeated_id_process
  394. self.repeated_id_owner = repeated_id_owner
  395. self.allow_info_share = allow_info_share
  396. self.info_owner = info_owner
  397. self.with_sample_id = with_sample_id
  398. self.join_method = join_method
  399. self.new_sample_id = new_sample_id
  400. self.dh_params = copy.deepcopy(dh_params)
  401. self.cardinality_only = cardinality_only
  402. self.sync_cardinality = sync_cardinality
  403. self.cardinality_method = cardinality_method
  404. self.run_preprocess = run_preprocess
  405. self.intersect_preprocess_params = copy.deepcopy(intersect_preprocess_params)
  406. self.ecdh_params = copy.deepcopy(ecdh_params)
  407. def check(self):
  408. descr = "intersect param's "
  409. self.intersect_method = self.check_and_change_lower(self.intersect_method,
  410. [consts.RSA, consts.RAW, consts.DH, consts.ECDH],
  411. f"{descr}intersect_method")
  412. self.check_positive_integer(self.random_bit, f"{descr}random_bit")
  413. self.check_boolean(self.sync_intersect_ids, f"{descr}intersect_ids")
  414. self.join_role = self.check_and_change_lower(self.join_role,
  415. [consts.GUEST, consts.HOST],
  416. f"{descr}join_role")
  417. self.check_boolean(self.with_encode, f"{descr}with_encode")
  418. self.check_boolean(self.only_output_key, f"{descr}only_output_key")
  419. self.join_method = self.check_and_change_lower(self.join_method, [consts.INNER_JOIN, consts.LEFT_JOIN],
  420. f"{descr}join_method")
  421. self.check_boolean(self.new_sample_id, f"{descr}new_sample_id")
  422. self.sample_id_generator = self.check_and_change_lower(self.sample_id_generator,
  423. [consts.GUEST, consts.HOST],
  424. f"{descr}sample_id_generator")
  425. if self.join_method == consts.LEFT_JOIN:
  426. if not self.sync_intersect_ids:
  427. raise ValueError(f"Cannot perform left join without sync intersect ids")
  428. self.check_boolean(self.run_cache, f"{descr} run_cache")
  429. self.encode_params.check()
  430. self.raw_params.check()
  431. self.rsa_params.check()
  432. self.dh_params.check()
  433. self.ecdh_params.check()
  434. self.check_boolean(self.cardinality_only, f"{descr}cardinality_only")
  435. self.check_boolean(self.sync_cardinality, f"{descr}sync_cardinality")
  436. self.check_boolean(self.run_preprocess, f"{descr}run_preprocess")
  437. self.intersect_preprocess_params.check()
  438. if self.cardinality_only:
  439. if self.cardinality_method not in [consts.RSA, consts.DH, consts.ECDH]:
  440. raise ValueError(f"cardinality-only mode only support rsa, dh, ecdh.")
  441. if self.cardinality_method == consts.RSA and self.rsa_params.split_calculation:
  442. raise ValueError(f"cardinality-only mode only supports unified calculation.")
  443. if self.run_preprocess:
  444. if self.intersect_preprocess_params.false_positive_rate < 0.01:
  445. raise ValueError(f"for preprocessing ids, false_positive_rate must be no less than 0.01")
  446. if self.cardinality_only:
  447. raise ValueError(f"cardinality_only mode cannot run preprocessing.")
  448. if self.run_cache:
  449. if self.intersect_method not in [consts.RSA, consts.DH, consts.ECDH]:
  450. raise ValueError(f"Only rsa, dh, ecdh method supports cache.")
  451. if self.intersect_method == consts.RSA and self.rsa_params.split_calculation:
  452. raise ValueError(f"RSA split_calculation does not support cache.")
  453. if self.cardinality_only:
  454. raise ValueError(f"Cache is not available for cardinality_only mode.")
  455. if self.run_preprocess:
  456. raise ValueError(f"Preprocessing does not support cache.")
  457. return True