123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- #
- # Copyright 2021 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import gmpy2
- import math
- import uuid
- import numpy as np
- from federatedml.secureprotol.hash.hash_factory import Hash
- from federatedml.util import consts, LOGGER
- SALT_LENGTH = 8
- class BitArray(object):
- def __init__(self, bit_count, hash_func_count, hash_method, random_state, salt=None):
- self.bit_count = bit_count
- self._array = np.zeros((bit_count + 63) // 64, dtype='uint64')
- self.bit_count = self._array.size * 64
- self.random_state = random_state
- # self.hash_encoder = Hash(hash_method, False)
- self.hash_method = hash_method
- self.hash_func_count = hash_func_count
- self.id = str(uuid.uuid4())
- self.salt = salt
- if salt is None:
- self.salt = self._generate_salt()
- def _generate_salt(self):
- random_state = np.random.RandomState(self.random_state)
- def f(n):
- return str(n)[2:]
- return list(map(f, np.round(random_state.random(self.hash_func_count), SALT_LENGTH)))
- @property
- def sparsity(self):
- set_bit_count = sum(map(gmpy2.popcount, map(int, self._array)))
- return 1 - set_bit_count / self.bit_count
- def set_array(self, new_array):
- self._array = new_array
- def get_array(self):
- return self._array
- def merge_filter(self, other):
- if self.bit_count != other.bit_count:
- raise ValueError(f"cannot merge filters with different bit count")
- self._array |= other._array
- def get_ind_set(self, x):
- hash_encoder = Hash(self.hash_method, False)
- return set(int(hash_encoder.compute(x,
- suffix_salt=self.salt[i]),
- 16) % self.bit_count for i in range(self.hash_func_count))
- def insert(self, x):
- """
- insert given instance to bit array with hash functions
- Parameters
- ----------
- x
- Returns
- -------
- """
- ind_set = self.get_ind_set(x)
- for ind in ind_set:
- self.set_bit(ind)
- return self._array
- def insert_ind_set(self, ind_set):
- """
- insert given ind collection to bit array with hash functions
- Parameters
- ----------
- ind_set
- Returns
- -------
- """
- for ind in ind_set:
- self.set_bit(ind)
- def check(self, x):
- """
- check whether given instance x exists in bit array
- Parameters
- ----------
- x
- Returns
- -------
- """
- hash_encoder = Hash(self.hash_method, False)
- for i in range(self.hash_func_count):
- ind = int(hash_encoder.compute(x, suffix_salt=self.salt[i]), 16) % self.bit_count
- if not self.query_bit(ind):
- return False
- return True
- def check_ind_set(self, ind_set):
- """
- check whether all bits in given ind set are filled
- Parameters
- ----------
- ind_set
- Returns
- -------
- """
- for ind in ind_set:
- if not self.query_bit(ind):
- return False
- return True
- def set_bit(self, ind):
- """
- set bit at given bit index
- Parameters
- ----------
- ind
- Returns
- -------
- """
- pos = ind >> 6
- bit_pos = ind & 63
- self._array[pos] |= np.uint64(1 << bit_pos)
- def query_bit(self, ind):
- """
- query bit != 0
- Parameters
- ----------
- ind
- Returns
- -------
- """
- pos = ind >> 6
- bit_pos = ind & 63
- return (self._array[pos] & np.uint64(1 << bit_pos)) != 0
- @staticmethod
- def get_filter_param(n, p):
- """
- Parameters
- ----------
- n: items to store in filter
- p: target false positive rate
- Returns
- -------
- """
- # bit count
- m = math.ceil(-n * math.log(p) / (math.pow(math.log(2), 2)))
- # hash func count
- k = round(m / n * math.log(2))
- if k < consts.MIN_HASH_FUNC_COUNT:
- LOGGER.info(f"computed k value {k} is smaller than min hash func count limit, "
- f"set to {consts.MIN_HASH_FUNC_COUNT}")
- k = consts.MIN_HASH_FUNC_COUNT
- # update bit count so that target fpr = p
- m = round(-n * k / math.log(1 - math.pow(p, 1 / k)))
- if k > consts.MAX_HASH_FUNC_COUNT:
- LOGGER.info(f"computed k value {k} is greater than max hash func count limit, "
- f"set to {consts.MAX_HASH_FUNC_COUNT}")
- k = consts.MAX_HASH_FUNC_COUNT
- # update bit count so that target fpr = p
- m = round(-n * k / math.log(1 - math.pow(p, 1 / k)))
- return m, k
|