# quantile_summaries_test.py (2.1 KB)
  1. import math
  2. import unittest
  3. import numpy as np
  4. from federatedml.feature.binning.quantile_summaries import QuantileSummaries
  5. class TestQuantileSummaries(unittest.TestCase):
  6. def setUp(self):
  7. self.percentile_rate = list(range(0, 100, 1))
  8. self.data_num = 10000
  9. np.random.seed(15)
  10. self.table = np.random.randn(self.data_num)
  11. compress_thres = 1000
  12. head_size = 500
  13. self.error = 0.00001
  14. self.quantile_summaries = QuantileSummaries(compress_thres=compress_thres,
  15. head_size=head_size,
  16. error=self.error)
  17. def test_correctness(self):
  18. for num in self.table:
  19. self.quantile_summaries.insert(num)
  20. x = sorted(self.table)
  21. for q_num in self.percentile_rate:
  22. percent = q_num / 100
  23. sk2 = self.quantile_summaries.query(percent)
  24. min_rank = math.floor((percent - 2 * self.error) * self.data_num)
  25. max_rank = math.ceil((percent + 2 * self.error) * self.data_num)
  26. if min_rank < 0:
  27. min_rank = 0
  28. if max_rank > len(x) - 1:
  29. max_rank = len(x) - 1
  30. min_value, max_value = x[min_rank], x[max_rank]
  31. try:
  32. self.assertTrue(min_value <= sk2 <= max_value)
  33. except AssertionError as e:
  34. print(f"min_value: {min_value}, max_value: {max_value}, sk2: {sk2}, percent: {percent},"
  35. f"total_max_value: {x[-1]}")
  36. raise AssertionError(e)
  37. def test_multi(self):
  38. for n in range(5):
  39. self.table = np.random.randn(self.data_num)
  40. compress_thres = 10000
  41. head_size = 5000
  42. self.quantile_summaries = QuantileSummaries(compress_thres=compress_thres,
  43. head_size=head_size,
  44. error=self.error)
  45. self.test_correctness()
  46. if __name__ == '__main__':
  47. unittest.main()