intersect_model.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import numpy as np
  17. from fate_arch.common.base_utils import fate_uuid
  18. from federatedml.feature.instance import Instance
  19. from federatedml.model_base import Metric, MetricMeta
  20. from federatedml.model_base import ModelBase
  21. from federatedml.param.intersect_param import IntersectParam
  22. from federatedml.secureprotol.hash.hash_factory import Hash
  23. from federatedml.statistic import data_overview
  24. from federatedml.statistic.intersect import RawIntersectionHost, RawIntersectionGuest, RsaIntersectionHost, \
  25. RsaIntersectionGuest, DhIntersectionGuest, DhIntersectionHost, EcdhIntersectionHost, EcdhIntersectionGuest
  26. from federatedml.statistic.intersect.match_id_process import MatchIDIntersect
  27. from federatedml.transfer_variable.transfer_class.intersection_func_transfer_variable import \
  28. IntersectionFuncTransferVariable
  29. from federatedml.util import consts, LOGGER, data_format_preprocess
  30. class IntersectModelBase(ModelBase):
  31. def __init__(self):
  32. super().__init__()
  33. self.intersection_obj = None
  34. self.proc_obj = None
  35. # self.intersect_num = -1
  36. self.intersect_rate = -1
  37. self.unmatched_num = -1
  38. self.unmatched_rate = -1
  39. self.intersect_ids = None
  40. self.metric_name = "intersection"
  41. self.metric_namespace = "train"
  42. self.metric_type = "INTERSECTION"
  43. self.model_param_name = "IntersectModelParam"
  44. self.model_meta_name = "IntersectModelMeta"
  45. self.model_param = IntersectParam()
  46. self.use_match_id_process = False
  47. self.role = None
  48. self.intersect_method = None
  49. self.match_id_num = None
  50. self.match_id_intersect_num = -1
  51. self.recovered_num = -1
  52. self.guest_party_id = None
  53. self.host_party_id = None
  54. self.host_party_id_list = None
  55. self.transfer_variable = IntersectionFuncTransferVariable()
  56. def _init_model(self, params):
  57. self.model_param = params
  58. self.intersect_preprocess_params = params.intersect_preprocess_params
  59. def init_intersect_method(self):
  60. if self.model_param.cardinality_only:
  61. self.intersect_method = self.model_param.cardinality_method
  62. else:
  63. self.intersect_method = self.model_param.intersect_method
  64. LOGGER.info("Using {} intersection, role is {}".format(self.intersect_method, self.role))
  65. self.host_party_id_list = self.component_properties.host_party_idlist
  66. self.guest_party_id = self.component_properties.guest_partyid
  67. if self.role not in [consts.HOST, consts.GUEST]:
  68. raise ValueError("role {} is not support".format(self.role))
  69. def get_model_summary(self):
  70. return {"intersect_num": self.match_id_intersect_num, "intersect_rate": self.intersect_rate,
  71. "cardinality_only": self.intersection_obj.cardinality_only,
  72. "unique_id_num": self.match_id_num}
  73. def sync_use_match_id(self):
  74. raise NotImplementedError(f"Should not be called here.")
  75. def __share_info(self, data):
  76. LOGGER.info("Start to share information with another role")
  77. info_share = self.transfer_variable.info_share_from_guest if self.model_param.info_owner == consts.GUEST else \
  78. self.transfer_variable.info_share_from_host
  79. party_role = consts.GUEST if self.model_param.info_owner == consts.HOST else consts.HOST
  80. if self.role == self.model_param.info_owner:
  81. if data.schema.get('header') is not None:
  82. try:
  83. share_info_col_idx = data.schema.get('header').index(consts.SHARE_INFO_COL_NAME)
  84. one_data = data.first()
  85. if isinstance(one_data[1], Instance):
  86. share_data = data.join(self.intersect_ids, lambda d, i: [d.features[share_info_col_idx]])
  87. else:
  88. share_data = data.join(self.intersect_ids, lambda d, i: [d[share_info_col_idx]])
  89. info_share.remote(share_data,
  90. role=party_role,
  91. idx=-1)
  92. LOGGER.info("Remote share information to {}".format(party_role))
  93. except Exception as e:
  94. LOGGER.warning("Something unexpected:{}, share a empty information to {}".format(e, party_role))
  95. share_data = self.intersect_ids.mapValues(lambda v: ['null'])
  96. info_share.remote(share_data,
  97. role=party_role,
  98. idx=-1)
  99. else:
  100. raise ValueError(
  101. "'allow_info_share' is true, and 'info_owner' is {}, but can not get header in data, information sharing not done".format(
  102. self.model_param.info_owner))
  103. else:
  104. self.intersect_ids = info_share.get(idx=0)
  105. self.intersect_ids.schema['header'] = [consts.SHARE_INFO_COL_NAME]
  106. LOGGER.info(
  107. "Get share information from {}, header:{}".format(self.model_param.info_owner, self.intersect_ids))
  108. return self.intersect_ids
  109. def __sync_join_id(self, data, intersect_data):
  110. LOGGER.debug(f"data count: {data.count()}")
  111. LOGGER.debug(f"intersect_data count: {intersect_data.count()}")
  112. if self.model_param.sample_id_generator == consts.GUEST:
  113. sync_join_id = self.transfer_variable.join_id_from_guest
  114. else:
  115. sync_join_id = self.transfer_variable.join_id_from_host
  116. if self.role == self.model_param.sample_id_generator:
  117. join_data = data.subtractByKey(intersect_data)
  118. # LOGGER.debug(f"join_data count: {join_data.count()}")
  119. if self.model_param.new_sample_id:
  120. if self.model_param.only_output_key:
  121. join_data = join_data.map(lambda k, v: (fate_uuid(), None))
  122. join_id = join_data
  123. else:
  124. join_data = join_data.map(lambda k, v: (fate_uuid(), v))
  125. join_id = join_data.mapValues(lambda v: None)
  126. sync_join_id.remote(join_id)
  127. result_data = intersect_data.union(join_data)
  128. else:
  129. join_id = join_data.map(lambda k, v: (k, None))
  130. result_data = data
  131. if self.model_param.only_output_key:
  132. result_data = data.mapValues(lambda v: None)
  133. sync_join_id.remote(join_id)
  134. else:
  135. join_id = sync_join_id.get(idx=0)
  136. # LOGGER.debug(f"received join_id count: {join_id.count()}")
  137. join_data = join_id
  138. if not self.model_param.only_output_key:
  139. feature_shape = data.first()[1].features.shape[0]
  140. def _generate_nan_instance():
  141. filler = np.empty((feature_shape,))
  142. filler.fill(np.nan)
  143. return filler
  144. join_data = join_id.mapValues(lambda v: Instance(features=_generate_nan_instance()))
  145. result_data = intersect_data.union(join_data)
  146. LOGGER.debug(f"result data count: {result_data.count()}")
  147. return result_data
  148. def callback(self):
  149. meta_info = {"intersect_method": self.intersect_method,
  150. "join_method": self.model_param.join_method}
  151. if self.use_match_id_process:
  152. self.callback_metric(metric_name=self.metric_name,
  153. metric_namespace=self.metric_namespace,
  154. metric_data=[Metric("intersect_count", self.match_id_intersect_num),
  155. Metric("input_match_id_count", self.match_id_num),
  156. Metric("intersect_rate", self.intersect_rate),
  157. Metric("unmatched_count", self.unmatched_num),
  158. Metric("unmatched_rate", self.unmatched_rate),
  159. Metric("intersect_sample_id_count", self.recovered_num)])
  160. else:
  161. self.callback_metric(metric_name=self.metric_name,
  162. metric_namespace=self.metric_namespace,
  163. metric_data=[Metric("intersect_count", self.match_id_intersect_num),
  164. Metric("input_match_id_count", self.match_id_num),
  165. Metric("intersect_rate", self.intersect_rate),
  166. Metric("unmatched_count", self.unmatched_num),
  167. Metric("unmatched_rate", self.unmatched_rate)])
  168. self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
  169. metric_name=self.metric_name,
  170. metric_meta=MetricMeta(name=self.metric_name,
  171. metric_type=self.metric_type,
  172. extra_metas=meta_info)
  173. )
  174. def callback_cache_meta(self, intersect_meta):
  175. metric_name = f"{self.metric_name}_cache_meta"
  176. self.tracker.set_metric_meta(metric_namespace=self.metric_namespace,
  177. metric_name=metric_name,
  178. metric_meta=MetricMeta(name=f"{self.metric_name}_cache_meta",
  179. metric_type=self.metric_type,
  180. extra_metas=intersect_meta)
  181. )
  182. def fit(self, data):
  183. if self.component_properties.caches:
  184. LOGGER.info(f"Cache provided, will enter intersect online process.")
  185. return self.intersect_online_process(data, self.component_properties.caches)
  186. self.init_intersect_method()
  187. if data_overview.check_with_inst_id(data):
  188. self.use_match_id_process = True
  189. LOGGER.info(f"use match_id_process")
  190. self.sync_use_match_id()
  191. if self.use_match_id_process:
  192. if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
  193. raise ValueError("While multi-host, sample_id_generator should be guest.")
  194. if self.intersect_method == consts.RAW:
  195. if self.model_param.sample_id_generator != self.intersection_obj.join_role:
  196. raise ValueError(f"When using raw intersect with match id process,"
  197. f"'join_role' should be same role as 'sample_id_generator'")
  198. else:
  199. if not self.model_param.sync_intersect_ids:
  200. if self.model_param.sample_id_generator != consts.GUEST:
  201. self.model_param.sample_id_generator = consts.GUEST
  202. LOGGER.warning(f"when not sync_intersect_ids with match id process,"
  203. f"sample_id_generator is set to Guest")
  204. self.proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
  205. self.proc_obj.new_sample_id = self.model_param.new_sample_id
  206. if data_overview.check_with_inst_id(data) or self.model_param.with_sample_id:
  207. self.proc_obj.use_sample_id()
  208. match_data = self.proc_obj.recover(data=data)
  209. self.match_id_num = match_data.count()
  210. if self.intersection_obj.run_cache:
  211. self.cache_output = self.intersection_obj.generate_cache(match_data)
  212. intersect_meta = self.intersection_obj.get_intersect_method_meta()
  213. self.callback_cache_meta(intersect_meta)
  214. return
  215. if self.intersection_obj.cardinality_only:
  216. self.intersection_obj.run_cardinality(match_data)
  217. else:
  218. intersect_data = match_data
  219. if self.model_param.run_preprocess:
  220. intersect_data = self.run_preprocess(match_data)
  221. self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)
  222. if self.intersect_ids:
  223. self.match_id_intersect_num = self.intersect_ids.count()
  224. else:
  225. if self.model_param.join_method == consts.LEFT_JOIN:
  226. raise ValueError(f"Only data with match_id may apply left_join method. Please check input data format")
  227. self.match_id_num = data.count()
  228. if self.intersection_obj.run_cache:
  229. self.cache_output = self.intersection_obj.generate_cache(data)
  230. intersect_meta = self.intersection_obj.get_intersect_method_meta()
  231. # LOGGER.debug(f"callback intersect meta is: {intersect_meta}")
  232. self.callback_cache_meta(intersect_meta)
  233. return
  234. if self.intersection_obj.cardinality_only:
  235. self.intersection_obj.run_cardinality(data)
  236. else:
  237. intersect_data = data
  238. if self.model_param.run_preprocess:
  239. intersect_data = self.run_preprocess(data)
  240. self.intersect_ids = self.intersection_obj.run_intersect(intersect_data)
  241. if self.intersect_ids:
  242. self.match_id_intersect_num = self.intersect_ids.count()
  243. if self.intersection_obj.cardinality_only:
  244. if self.intersection_obj.intersect_num is not None:
  245. # data_count = data.count()
  246. self.match_id_intersect_num = self.intersection_obj.intersect_num
  247. self.intersect_rate = self.match_id_intersect_num / self.match_id_num
  248. self.unmatched_num = self.match_id_num - self.match_id_intersect_num
  249. self.unmatched_rate = 1 - self.intersect_rate
  250. self.set_summary(self.get_model_summary())
  251. self.callback()
  252. return None
  253. if self.use_match_id_process:
  254. if self.model_param.sync_intersect_ids:
  255. self.intersect_ids = self.proc_obj.expand(self.intersect_ids, match_data=match_data)
  256. else:
  257. # self.intersect_ids = match_data
  258. self.intersect_ids = self.proc_obj.expand(self.intersect_ids,
  259. match_data=match_data,
  260. owner_only=True)
  261. if self.intersect_ids:
  262. self.recovered_num = self.intersect_ids.count()
  263. if self.model_param.only_output_key and self.intersect_ids:
  264. self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
  265. # self.intersect_ids.schema = {"match_id_name": data.schema["match_id_name"],
  266. # "sid": data.schema.get("sid")}
  267. self.intersect_ids.schema = data_format_preprocess.DataFormatPreProcess.clean_header(data.schema)
  268. LOGGER.info("Finish intersection")
  269. if self.intersect_ids:
  270. self.intersect_rate = self.match_id_intersect_num / self.match_id_num
  271. self.unmatched_num = self.match_id_num - self.match_id_intersect_num
  272. self.unmatched_rate = 1 - self.intersect_rate
  273. self.set_summary(self.get_model_summary())
  274. self.callback()
  275. result_data = self.intersect_ids
  276. if not self.use_match_id_process and result_data:
  277. if self.intersection_obj.only_output_key:
  278. # result_data.schema = {"sid": data.schema.get("sid")}
  279. result_data.schema = data_format_preprocess.DataFormatPreProcess.clean_header(data.schema)
  280. LOGGER.debug(f"non-match-id & only_output_key, add sid to schema")
  281. else:
  282. result_data = self.intersection_obj.get_value_from_data(result_data, data)
  283. LOGGER.debug(f"not only_output_key, restore instance value")
  284. if self.model_param.join_method == consts.LEFT_JOIN:
  285. result_data = self.__sync_join_id(data, self.intersect_ids)
  286. result_data.schema = self.intersect_ids.schema
  287. return result_data
  288. def check_consistency(self):
  289. pass
  290. def load_intersect_meta(self, intersect_meta):
  291. if self.model_param.intersect_method != intersect_meta.get("intersect_method"):
  292. raise ValueError(f"Current intersect method must match to cache record.")
  293. if self.model_param.intersect_method == consts.RSA:
  294. self.model_param.rsa_params.hash_method = intersect_meta["hash_method"]
  295. self.model_param.rsa_params.final_hash_method = intersect_meta["final_hash_method"]
  296. self.model_param.rsa_params.salt = intersect_meta["salt"]
  297. self.model_param.rsa_params.random_bit = intersect_meta["random_bit"]
  298. elif self.model_param.intersect_method == consts.DH:
  299. self.model_param.dh_params.hash_method = intersect_meta["hash_method"]
  300. self.model_param.dh_params.salt = intersect_meta["salt"]
  301. elif self.model_param.intersect_method == consts.ECDH:
  302. self.model_param.ecdh_params.hash_method = intersect_meta["hash_method"]
  303. self.model_param.ecdh_params.salt = intersect_meta["salt"]
  304. self.model_param.ecdh_params.curve = intersect_meta["curve"]
  305. else:
  306. raise ValueError(f"{self.model_param.intersect_method} does not support cache.")
  307. def make_filter_process(self, data_instances, hash_operator):
  308. raise NotImplementedError("This method should not be called here")
  309. def get_filter_process(self, data_instances, hash_operator):
  310. raise NotImplementedError("This method should not be called here")
  311. def run_preprocess(self, data_instances):
  312. preprocess_hash_operator = Hash(self.model_param.intersect_preprocess_params.preprocess_method, False)
  313. if self.role == self.model_param.intersect_preprocess_params.filter_owner:
  314. data = self.make_filter_process(data_instances, preprocess_hash_operator)
  315. else:
  316. LOGGER.debug(f"before preprocess, data count: {data_instances.count()}")
  317. data = self.get_filter_process(data_instances, preprocess_hash_operator)
  318. LOGGER.debug(f"after preprocess, data count: {data.count()}")
  319. return data
  320. def intersect_online_process(self, data_inst, caches):
  321. # LOGGER.debug(f"caches is: {caches}")
  322. cache_data, cache_meta = list(caches.values())[0]
  323. intersect_meta = list(cache_meta.values())[0]["intersect_meta"]
  324. # LOGGER.debug(f"intersect_meta is: {intersect_meta}")
  325. self.callback_cache_meta(intersect_meta)
  326. self.load_intersect_meta(intersect_meta)
  327. self.init_intersect_method()
  328. self.intersection_obj.load_intersect_key(cache_meta)
  329. if data_overview.check_with_inst_id(data_inst):
  330. self.use_match_id_process = True
  331. LOGGER.info(f"use match_id_process")
  332. self.sync_use_match_id()
  333. intersect_data = data_inst
  334. self.match_id_num = data_inst.count()
  335. if self.use_match_id_process:
  336. if len(self.host_party_id_list) > 1 and self.model_param.sample_id_generator != consts.GUEST:
  337. raise ValueError("While multi-host, sample_id_generator should be guest.")
  338. if self.intersect_method == consts.RAW:
  339. if self.model_param.sample_id_generator != self.intersection_obj.join_role:
  340. raise ValueError(f"When using raw intersect with match id process,"
  341. f"'join_role' should be same role as 'sample_id_generator'")
  342. else:
  343. if not self.model_param.sync_intersect_ids:
  344. if self.model_param.sample_id_generator != consts.GUEST:
  345. self.model_param.sample_id_generator = consts.GUEST
  346. LOGGER.warning(f"when not sync_intersect_ids with match id process,"
  347. f"sample_id_generator is set to Guest")
  348. proc_obj = MatchIDIntersect(sample_id_generator=self.model_param.sample_id_generator, role=self.role)
  349. proc_obj.new_sample_id = self.model_param.new_sample_id
  350. if data_overview.check_with_inst_id(data_inst) or self.model_param.with_sample_id:
  351. proc_obj.use_sample_id()
  352. match_data = proc_obj.recover(data=data_inst)
  353. intersect_data = match_data
  354. self.match_id_num = match_data.count()
  355. if self.role == consts.HOST:
  356. cache_id = cache_meta[str(self.guest_party_id)].get("cache_id")
  357. self.transfer_variable.cache_id.remote(cache_id, role=consts.GUEST, idx=0)
  358. guest_cache_id = self.transfer_variable.cache_id.get(role=consts.GUEST, idx=0)
  359. self.match_id_num = list(cache_data.values())[0].count()
  360. if guest_cache_id != cache_id:
  361. raise ValueError(f"cache_id check failed. cache_id from host & guest must match.")
  362. elif self.role == consts.GUEST:
  363. for i, party_id in enumerate(self.host_party_id_list):
  364. cache_id = cache_meta[str(party_id)].get("cache_id")
  365. self.transfer_variable.cache_id.remote(cache_id,
  366. role=consts.HOST,
  367. idx=i)
  368. host_cache_id = self.transfer_variable.cache_id.get(role=consts.HOST, idx=i)
  369. if host_cache_id != cache_id:
  370. raise ValueError(f"cache_id check failed. cache_id from host & guest must match.")
  371. else:
  372. raise ValueError(f"Role {self.role} cannot run intersection transform.")
  373. self.intersect_ids = self.intersection_obj.run_cache_intersect(intersect_data, cache_data)
  374. self.match_id_intersect_num = self.intersect_ids.count()
  375. if self.use_match_id_process:
  376. if not self.model_param.sync_intersect_ids:
  377. self.intersect_ids = proc_obj.expand(self.intersect_ids,
  378. match_data=match_data,
  379. owner_only=True)
  380. else:
  381. self.intersect_ids = proc_obj.expand(self.intersect_ids, match_data=match_data)
  382. if self.intersect_ids:
  383. self.recovered_num = self.intersect_ids.count()
  384. if self.intersect_ids and self.model_param.only_output_key:
  385. self.intersect_ids = self.intersect_ids.mapValues(lambda v: Instance(inst_id=v.inst_id))
  386. # self.intersect_ids.schema = {"match_id_name": data_inst.schema["match_id_name"],
  387. # "sid": data_inst.schema.get("sid")}
  388. self.intersect_ids.schema = data_format_preprocess.DataFormatPreProcess.clean_header(data_inst.schema)
  389. LOGGER.info("Finish intersection")
  390. if self.intersect_ids:
  391. self.intersect_rate = self.match_id_intersect_num / self.match_id_num
  392. self.unmatched_num = self.match_id_num - self.match_id_intersect_num
  393. self.unmatched_rate = 1 - self.intersect_rate
  394. self.set_summary(self.get_model_summary())
  395. self.callback()
  396. result_data = self.intersect_ids
  397. if not self.use_match_id_process:
  398. if not self.intersection_obj.only_output_key and result_data:
  399. result_data = self.intersection_obj.get_value_from_data(result_data, data_inst)
  400. self.intersect_ids.schema = result_data.schema
  401. LOGGER.debug(f"not only_output_key, restore value called")
  402. if self.intersection_obj.only_output_key and result_data:
  403. # schema = {"sid": data_inst.schema.get("sid")}
  404. schema = data_format_preprocess.DataFormatPreProcess.clean_header(data_inst.schema)
  405. result_data = result_data.mapValues(lambda v: None)
  406. result_data.schema = schema
  407. self.intersect_ids.schema = schema
  408. if self.model_param.join_method == consts.LEFT_JOIN:
  409. result_data = self.__sync_join_id(data_inst, self.intersect_ids)
  410. result_data.schema = self.intersect_ids.schema
  411. return result_data
  412. class IntersectHost(IntersectModelBase):
  413. def __init__(self):
  414. super().__init__()
  415. self.role = consts.HOST
  416. def init_intersect_method(self):
  417. super().init_intersect_method()
  418. self.host_party_id = self.component_properties.local_partyid
  419. if self.intersect_method == consts.RSA:
  420. self.intersection_obj = RsaIntersectionHost()
  421. elif self.intersect_method == consts.RAW:
  422. self.intersection_obj = RawIntersectionHost()
  423. self.intersection_obj.tracker = self.tracker
  424. self.intersection_obj.task_version_id = self.task_version_id
  425. elif self.intersect_method == consts.DH:
  426. self.intersection_obj = DhIntersectionHost()
  427. elif self.intersect_method == consts.ECDH:
  428. self.intersection_obj = EcdhIntersectionHost()
  429. else:
  430. raise ValueError("intersect_method {} is not support yet".format(self.intersect_method))
  431. self.intersection_obj.host_party_id = self.host_party_id
  432. self.intersection_obj.guest_party_id = self.guest_party_id
  433. self.intersection_obj.host_party_id_list = self.host_party_id_list
  434. self.intersection_obj.load_params(self.model_param)
  435. self.model_param = self.intersection_obj.model_param
  436. def sync_use_match_id(self):
  437. self.transfer_variable.use_match_id.remote(self.use_match_id_process, role=consts.GUEST, idx=-1)
  438. LOGGER.info(f"sync use_match_id flag: {self.use_match_id_process} with Guest")
  439. def make_filter_process(self, data_instances, hash_operator):
  440. filter = self.intersection_obj.construct_filter(data_instances,
  441. self.intersect_preprocess_params.false_positive_rate,
  442. self.intersect_preprocess_params.hash_method,
  443. self.intersect_preprocess_params.random_state,
  444. hash_operator,
  445. self.intersect_preprocess_params.preprocess_salt)
  446. self.transfer_variable.intersect_filter_from_host.remote(filter, role=consts.GUEST, idx=0)
  447. LOGGER.debug(f"filter sent to guest")
  448. return data_instances
  449. def get_filter_process(self, data_instances, hash_operator):
  450. filter = self.transfer_variable.intersect_filter_from_guest.get(idx=0)
  451. LOGGER.debug(f"got filter from guest")
  452. filtered_data = data_instances.filter(lambda k, v: filter.check(
  453. hash_operator.compute(k, suffix_salt=self.intersect_preprocess_params.preprocess_salt)))
  454. return filtered_data
  455. class IntersectGuest(IntersectModelBase):
  456. def __init__(self):
  457. super().__init__()
  458. self.role = consts.GUEST
  459. def init_intersect_method(self):
  460. super().init_intersect_method()
  461. if self.intersect_method == consts.RSA:
  462. self.intersection_obj = RsaIntersectionGuest()
  463. elif self.intersect_method == consts.RAW:
  464. self.intersection_obj = RawIntersectionGuest()
  465. self.intersection_obj.tracker = self.tracker
  466. self.intersection_obj.task_version_id = self.task_version_id
  467. elif self.intersect_method == consts.DH:
  468. self.intersection_obj = DhIntersectionGuest()
  469. elif self.intersect_method == consts.ECDH:
  470. self.intersection_obj = EcdhIntersectionGuest()
  471. else:
  472. raise ValueError("intersect_method {} is not support yet".format(self.intersect_method))
  473. self.intersection_obj.guest_party_id = self.guest_party_id
  474. self.intersection_obj.host_party_id_list = self.host_party_id_list
  475. self.intersection_obj.load_params(self.model_param)
  476. def sync_use_match_id(self):
  477. host_use_match_id_flg = self.transfer_variable.use_match_id.get(idx=-1)
  478. LOGGER.info(f"received use_match_id flag from all hosts.")
  479. if any(flg != self.use_match_id_process for flg in host_use_match_id_flg):
  480. raise ValueError(f"Not all parties' input data have match_id, please check.")
  481. def make_filter_process(self, data_instances, hash_operator):
  482. filter = self.intersection_obj.construct_filter(data_instances,
  483. self.intersect_preprocess_params.false_positive_rate,
  484. self.intersect_preprocess_params.hash_method,
  485. self.intersect_preprocess_params.random_state,
  486. hash_operator,
  487. self.intersect_preprocess_params.preprocess_salt)
  488. self.transfer_variable.intersect_filter_from_guest.remote(filter, role=consts.HOST, idx=-1)
  489. LOGGER.debug(f"filter sent to guest")
  490. return data_instances
  491. def get_filter_process(self, data_instances, hash_operator):
  492. filter_list = self.transfer_variable.intersect_filter_from_host.get(idx=-1)
  493. LOGGER.debug(f"got filter from all host")
  494. filtered_data_list = [
  495. data_instances.filter(
  496. lambda k,
  497. v: filter.check(
  498. hash_operator.compute(
  499. k,
  500. suffix_salt=self.intersect_preprocess_params.preprocess_salt))) for filter in filter_list]
  501. filtered_data = self.intersection_obj.get_common_intersection(filtered_data_list, False)
  502. return filtered_data