base_intersect.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import functools
  17. from federatedml.param.intersect_param import IntersectParam
  18. from federatedml.statistic.intersect.intersect_preprocess import BitArray
  19. from federatedml.transfer_variable.transfer_class.intersection_func_transfer_variable \
  20. import IntersectionFuncTransferVariable
  21. from federatedml.util import LOGGER
  class Intersect(object):
      """Base class with shared state and helper routines for intersection.

      The ``run_*`` / ``generate_cache`` entry points raise
      ``NotImplementedError`` here and are meant to be overridden; the static
      helpers operate on table-like objects that expose ``join``, ``map``,
      ``mapPartitions``, ``reduce`` and ``count``.
      """

      def __init__(self):
          super().__init__()
          self.cache_id = None
          self.model_param = IntersectParam()
          self.transfer_variable = None
          self.cache_transfer_variable = IntersectionFuncTransferVariable().cache_id_from_host
          self.filter = None
          self.intersect_num = None
          self.cache = None
          self.model_param_name = "IntersectModelParam"
          self.model_meta_name = "IntersectModelMeta"
          self.intersect_method = None
          # Party ids are assigned later through the validating properties below.
          self._guest_id = None
          self._host_id = None
          self._host_id_list = None

      def load_params(self, param):
          """Store ``param`` and copy the options used by intersection onto self.

          Parameters
          ----------
          param: parameter object exposing the attributes read below
          """
          self.model_param = param
          self.only_output_key = param.only_output_key
          self.sync_intersect_ids = param.sync_intersect_ids
          self.cardinality_only = param.cardinality_only
          self.sync_cardinality = param.sync_cardinality
          self.cardinality_method = param.cardinality_method
          self.run_preprocess = param.run_preprocess
          self.intersect_preprocess_params = param.intersect_preprocess_params
          self.run_cache = param.run_cache

      @property
      def guest_party_id(self):
          """Guest party id, or None if it has not been set."""
          return self._guest_id

      @guest_party_id.setter
      def guest_party_id(self, guest_id):
          """Set the guest party id; raises ValueError if it is not an int."""
          if not isinstance(guest_id, int):
              raise ValueError("party id should be integer, but get {}".format(guest_id))
          self._guest_id = guest_id

      @property
      def host_party_id(self):
          """Host party id, or None if it has not been set."""
          return self._host_id

      @host_party_id.setter
      def host_party_id(self, host_id):
          """Set the host party id; raises ValueError if it is not an int."""
          if not isinstance(host_id, int):
              raise ValueError("party id should be integer, but get {}".format(host_id))
          self._host_id = host_id

      @property
      def host_party_id_list(self):
          """List of host party ids, or None if it has not been set."""
          return self._host_id_list

      @host_party_id_list.setter
      def host_party_id_list(self, host_id_list):
          """Set the host party id list; raises ValueError if it is not a list."""
          if not isinstance(host_id_list, list):
              raise ValueError(
                  "type host_party_id should be list, but get {} with {}".format(type(host_id_list), host_id_list))
          self._host_id_list = host_id_list

      def get_intersect_method_meta(self):
          """Hook with no behavior in the base class (returns None)."""
          pass

      def get_intersect_key(self, party_id):
          """Hook with no behavior in the base class (returns None)."""
          pass

      def load_intersect_key(self, cache_meta):
          """Hook with no behavior in the base class (returns None)."""
          pass

      def run_intersect(self, data_instances):
          """Not implemented in the base class; always raises NotImplementedError."""
          raise NotImplementedError("method should not be called here")

      def run_cardinality(self, data_instances):
          """Not implemented in the base class; always raises NotImplementedError."""
          raise NotImplementedError("method should not be called here")

      def generate_cache(self, data_instances):
          """Not implemented in the base class; always raises NotImplementedError."""
          raise NotImplementedError("method should not be called here")

      @staticmethod
      def extract_cache_list(cache_data, party_list):
          """Look up the cache entry of each party in ``party_list``.

          Parameters
          ----------
          cache_data: mapping keyed by ``str(party_id)``
          party_list: a party id, or a list of party ids

          Returns
          -------
          list with one entry per party, in ``party_list`` order; a party
          without an entry in ``cache_data`` yields None
          """
          if not isinstance(party_list, list):
              party_list = [party_list]
          cache_list = [cache_data.get(str(party_id)) for party_id in party_list]
          # cache_list always has one slot per party, so its length is the
          # number of participating hosts; cache_data holds the caches given.
          if (party_count := len(cache_list)) != (cache_count := len(cache_data.items())):
              LOGGER.warning(f"{cache_count} cache sets are given, "
                             f"but {party_count} hosts participate in current intersection task.")
          return cache_list

      def run_cache_intersect(self, data_instances, cache_data):
          """Not implemented in the base class; always raises NotImplementedError."""
          raise NotImplementedError("method should not be called here")

      def set_flowid(self, flowid=0):
          """Forward ``flowid`` to the transfer variable, if one is set."""
          if self.transfer_variable is not None:
              self.transfer_variable.set_flowid(flowid)

      @staticmethod
      def get_value_from_data(intersect_ids, data_instances):
          """Replace the values of ``intersect_ids`` with those of ``data_instances``.

          Keys are joined against ``data_instances`` and the schema of
          ``data_instances`` is copied onto the result. A None input is
          returned unchanged.
          """
          if intersect_ids is not None:
              intersect_ids = intersect_ids.join(data_instances, lambda i, d: d)
              intersect_ids.schema = data_instances.schema
              LOGGER.info("obtain intersect data_instances!")
          return intersect_ids

      @staticmethod
      def get_common_intersection(intersect_ids_list: list, keep_encrypt_ids=False):
          """Join all tables in ``intersect_ids_list`` into their common keys.

          Parameters
          ----------
          intersect_ids_list: list of tables
          keep_encrypt_ids: bool, if True joined values are combined with
              ``+``; otherwise every joined value is None

          Returns
          -------
          the single table unchanged when the list has one element, the
          successive join of all tables otherwise, or None for an empty list
          """
          if len(intersect_ids_list) == 1:
              return intersect_ids_list[0]

          if keep_encrypt_ids:
              def f(v_prev, v):
                  return v_prev + v
          else:
              def f(v_prev, v):
                  return None

          intersect_ids = None
          for value in intersect_ids_list:
              if intersect_ids is None:
                  # Nothing accumulated yet: start from this table.
                  intersect_ids = value
                  continue
              intersect_ids = intersect_ids.join(value, f)
          return intersect_ids

      @staticmethod
      def extract_intersect_ids(intersect_ids, all_ids, keep_both=False):
          """Join ``intersect_ids`` with ``all_ids`` on key.

          The joined value is ``[e, h]`` (value from ``intersect_ids``, value
          from ``all_ids``) when ``keep_both`` is True, else only ``h``.
          """
          if keep_both:
              intersect_ids = intersect_ids.join(all_ids, lambda e, h: [e, h])
          else:
              intersect_ids = intersect_ids.join(all_ids, lambda e, h: h)
          return intersect_ids

      @staticmethod
      def filter_intersect_ids(encrypt_intersect_ids, keep_encrypt_ids=False):
          """Swap key and value of each table, then keep the common new keys.

          Parameters
          ----------
          encrypt_intersect_ids: non-empty list of tables
          keep_encrypt_ids: bool, if True each record ``(k, v)`` becomes
              ``(v, [k])``; otherwise ``(v, None)``

          Returns
          -------
          the mapped table for a single input, or the common intersection of
          all mapped tables when several are given
          """
          if keep_encrypt_ids:
              def f(k, v):
                  return (v, [k])
          else:
              def f(k, v):
                  return (v, None)

          if len(encrypt_intersect_ids) > 1:
              raw_intersect_ids = [e.map(f) for e in encrypt_intersect_ids]
              intersect_ids = Intersect.get_common_intersection(raw_intersect_ids, keep_encrypt_ids)
          else:
              intersect_ids = encrypt_intersect_ids[0]
              intersect_ids = intersect_ids.map(f)
          return intersect_ids

      @staticmethod
      def map_raw_id_to_encrypt_id(raw_id_data, encrypt_id_data, keep_value=False):
          """Re-key ``raw_id_data`` by the keys of ``encrypt_id_data``.

          ``encrypt_id_data`` is first inverted to ``(value, key)`` and joined
          with ``raw_id_data``. The result is keyed by the original
          ``encrypt_id_data`` key; its value is the ``raw_id_data`` value when
          ``keep_value`` is True, else None.
          """
          encrypt_id_data_exchange_kv = encrypt_id_data.map(lambda k, v: (v, k))
          encrypt_raw_id = raw_id_data.join(encrypt_id_data_exchange_kv, lambda r, e: (e, r))
          if keep_value:
              encrypt_common_id = encrypt_raw_id.map(lambda k, v: (v[0], v[1]))
          else:
              encrypt_common_id = encrypt_raw_id.map(lambda k, v: (v[0], None))
          return encrypt_common_id

      @staticmethod
      def map_encrypt_id_to_raw_id(encrypt_id_data, raw_id_data, keep_encrypt_id=True):
          """
          Parameters
          ----------
          encrypt_id_data: E(id)
          raw_id_data: (E(id), (id, v))
          keep_encrypt_id: bool

          Returns
          -------
          (id, E(id)) if keep_encrypt_id is True, else (id, None)
          """
          encrypt_id_raw_id = raw_id_data.join(encrypt_id_data, lambda r, e: r)
          if keep_encrypt_id:
              raw_id = encrypt_id_raw_id.map(lambda k, v: (v[0], k))
          else:
              raw_id = encrypt_id_raw_id.map(lambda k, v: (v[0], None))
          return raw_id

      @staticmethod
      def hash(value, hash_operator, salt=''):
          """Return ``hash_operator.compute(value, suffix_salt=salt)``."""
          h_value = hash_operator.compute(value, suffix_salt=salt)
          return h_value

      @staticmethod
      def insert_key(kv_iterator, filter, hash_operator=None, salt=None):
          """Insert every key of ``kv_iterator`` into ``filter``.

          Keys are passed through ``hash_operator.compute`` (with
          ``suffix_salt=salt``) first when ``hash_operator`` is given.

          Returns
          -------
          the result of the last ``filter.insert`` call, or None when the
          iterator is empty
          """
          # NOTE(review): an empty iterator yields None, which the ``x | y``
          # reduce in construct_filter would receive - confirm empty partitions
          # cannot reach this function.
          res_filter = None
          for k, _ in kv_iterator:
              if hash_operator:
                  res_filter = filter.insert(hash_operator.compute(k, suffix_salt=salt))
              else:
                  res_filter = filter.insert(k)
          return res_filter

      @staticmethod
      def count_key_in_filter(kv_iterator, filter):
          """Sum ``filter.check(k)`` over every key of ``kv_iterator``."""
          count = 0
          for k, _ in kv_iterator:
              count += filter.check(k)
          return count

      @staticmethod
      def construct_filter(data, false_positive_rate, hash_method, random_state, hash_operator=None, salt=None):
          """Build a ``BitArray`` filter holding the keys of ``data``.

          Parameters
          ----------
          data: table whose keys are inserted
          false_positive_rate: passed with the record count to
              ``BitArray.get_filter_param`` to size the filter
          hash_method, random_state: forwarded to the ``BitArray`` constructor
          hash_operator, salt: optional, forwarded to ``insert_key``

          Returns
          -------
          BitArray whose array is the ``|``-reduction of the per-partition
          results of ``insert_key``
          """
          n = data.count()
          m, k = BitArray.get_filter_param(n, false_positive_rate)
          filter = BitArray(m, k, hash_method, random_state)
          LOGGER.debug(f"filter bit count is: {filter.bit_count}")
          LOGGER.debug(f"filter hash func count: {filter.hash_func_count}")
          f = functools.partial(Intersect.insert_key, filter=filter, hash_operator=hash_operator, salt=salt)
          # Each partition inserts into its own copy; OR the results together.
          new_array = data.mapPartitions(f).reduce(lambda x, y: x | y)
          LOGGER.debug("filter array obtained")
          filter.set_array(new_array)
          # LOGGER.debug(f"after insert, filter sparsity is: {filter.sparsity}")
          return filter