intersect_preprocess.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import gmpy2
  17. import math
  18. import uuid
  19. import numpy as np
  20. from federatedml.secureprotol.hash.hash_factory import Hash
  21. from federatedml.util import consts, LOGGER
  # Number of decimal places each random float is rounded to (via np.round) before it is turned into a salt string.
  22. SALT_LENGTH = 8
  class BitArray(object):
      """
      Bloom-filter style bit array stored as a numpy ``uint64`` word array.

      An item is mapped to one bit index per hash function: the item is hashed
      together with that function's salt, the hexadecimal digest is read as an
      integer and reduced modulo ``bit_count``.
      """

      def __init__(self, bit_count, hash_func_count, hash_method, random_state, salt=None):
          """
          Parameters
          ----------
          bit_count: requested number of bits, rounded up to a multiple of 64
          hash_func_count: number of salted hash evaluations per item
          hash_method: hash method identifier passed to ``Hash``
          random_state: seed of the salt generator, only used when ``salt`` is None
          salt: optional sequence of salt strings, one per hash function
          """
          # Storage is allocated in whole 64-bit words, so the effective bit
          # count is the word count times 64 (>= the requested bit_count).
          self._array = np.zeros((bit_count + 63) // 64, dtype='uint64')
          self.bit_count = self._array.size * 64
          self.random_state = random_state
          # NOTE(review): only the hash method name is kept; the Hash encoder is
          # rebuilt on every call (presumably to keep instances cheap to
          # serialize - confirm before caching it on self).
          self.hash_method = hash_method
          self.hash_func_count = hash_func_count
          self.id = str(uuid.uuid4())
          self.salt = self._generate_salt() if salt is None else salt

      def _generate_salt(self):
          """
          Derive one salt string per hash function from ``random_state``.

          Returns
          -------
          list of str: textual form of each random float, rounded to
          SALT_LENGTH decimals, with its first two characters removed
          """
          rng = np.random.RandomState(self.random_state)
          rounded = np.round(rng.random(self.hash_func_count), SALT_LENGTH)
          # NOTE(review): the string format is kept exactly as before; salts feed
          # the hash, so any change here would change every bit index.
          return [str(value)[2:] for value in rounded]

      @property
      def sparsity(self):
          """
          Returns
          -------
          fraction of bits in the array that are not set
          """
          set_bit_count = sum(gmpy2.popcount(int(word)) for word in self._array)
          return 1 - set_bit_count / self.bit_count

      def set_array(self, new_array):
          """
          replace the backing word array
          Parameters
          ----------
          new_array: array of 64-bit words
          Returns
          -------
          """
          self._array = new_array
          # keep the "bit_count == number of words * 64" invariant established
          # in __init__, otherwise index arithmetic would use a stale size
          self.bit_count = np.size(new_array) * 64

      def get_array(self):
          """
          Returns
          -------
          the backing word array (not a copy)
          """
          return self._array

      def merge_filter(self, other):
          """
          OR the bits of another filter into this one, in place
          Parameters
          ----------
          other: BitArray with the same bit count
          Returns
          -------
          Raises
          ------
          ValueError: if the two filters have different bit counts
          """
          if self.bit_count != other.bit_count:
              raise ValueError("cannot merge filters with different bit count")
          self._array |= other._array

      def _iter_indices(self, x):
          """
          lazily yield the bit index produced by each salted hash of x
          """
          hash_encoder = Hash(self.hash_method, False)
          for i in range(self.hash_func_count):
              digest = hash_encoder.compute(x, suffix_salt=self.salt[i])
              yield int(digest, 16) % self.bit_count

      def get_ind_set(self, x):
          """
          compute the set of bit indices that represent given instance
          Parameters
          ----------
          x
          Returns
          -------
          set of int
          """
          return set(self._iter_indices(x))

      def insert(self, x):
          """
          insert given instance to bit array with hash functions
          Parameters
          ----------
          x
          Returns
          -------
          the backing word array after insertion
          """
          # indices are fully computed before any bit is touched
          self.insert_ind_set(self.get_ind_set(x))
          return self._array

      def insert_ind_set(self, ind_set):
          """
          set every bit whose index is in given collection
          Parameters
          ----------
          ind_set
          Returns
          -------
          """
          for ind in ind_set:
              self.set_bit(ind)

      def check(self, x):
          """
          check whether given instance x exists in bit array
          Parameters
          ----------
          x
          Returns
          -------
          False as soon as one of the instance's bits is unset, True otherwise
          """
          # the generator keeps hashing lazy, so all() stops at the first miss
          return all(self.query_bit(ind) for ind in self._iter_indices(x))

      def check_ind_set(self, ind_set):
          """
          check whether all bits in given ind set are filled
          Parameters
          ----------
          ind_set
          Returns
          -------
          True if every indexed bit is set (also for an empty collection)
          """
          return all(self.query_bit(ind) for ind in ind_set)

      @staticmethod
      def _locate(ind):
          """
          split a bit index into (word index, single-bit uint64 mask)
          """
          # int() normalises numpy integers so the shift below is done with
          # unbounded Python ints and cannot overflow at bit 63
          ind = int(ind)
          return ind >> 6, np.uint64(1 << (ind & 63))

      def set_bit(self, ind):
          """
          set bit at given bit index
          Parameters
          ----------
          ind
          Returns
          -------
          """
          pos, mask = self._locate(ind)
          self._array[pos] |= mask

      def query_bit(self, ind):
          """
          query bit != 0
          Parameters
          ----------
          ind
          Returns
          -------
          bool
          """
          pos, mask = self._locate(ind)
          return bool((self._array[pos] & mask) != 0)

      @staticmethod
      def get_filter_param(n, p):
          """
          compute bit count and hash function count for a filter
          Parameters
          ----------
          n: items to store in filter
          p: target false positive rate
          Returns
          -------
          (m, k): bit count and hash func count; k is clamped to
          [consts.MIN_HASH_FUNC_COUNT, consts.MAX_HASH_FUNC_COUNT]
          Raises
          ------
          ValueError: if n is not positive or p is not strictly between 0 and 1
          """
          if n <= 0:
              raise ValueError(f"item count n must be positive, got {n}")
          if not 0 < p < 1:
              raise ValueError(f"false positive rate p must be in (0, 1), got {p}")

          # bit count
          m = math.ceil(-n * math.log(p) / (math.pow(math.log(2), 2)))
          # hash func count
          k = round(m / n * math.log(2))

          clamped = False
          if k < consts.MIN_HASH_FUNC_COUNT:
              LOGGER.info(f"computed k value {k} is smaller than min hash func count limit, "
                          f"set to {consts.MIN_HASH_FUNC_COUNT}")
              k = consts.MIN_HASH_FUNC_COUNT
              clamped = True
          if k > consts.MAX_HASH_FUNC_COUNT:
              LOGGER.info(f"computed k value {k} is greater than max hash func count limit, "
                          f"set to {consts.MAX_HASH_FUNC_COUNT}")
              k = consts.MAX_HASH_FUNC_COUNT
              clamped = True
          if clamped:
              # update bit count so that target fpr = p
              m = round(-n * k / math.log(1 - math.pow(p, 1 / k)))
          return m, k