quantile_binning_test.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
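"""Unit tests for federatedml's QuantileBinning.

Covers quantile split-point correctness on dense and sparse tables, the
abnormal-value list, and an optional large-data smoke test.
"""
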
import unittest
import uuid

import numpy as np

from fate_arch.session import computing_session as session
from federatedml.feature.binning.quantile_binning import QuantileBinning
from federatedml.param.feature_binning_param import FeatureBinningParam
from federatedml.feature.instance import Instance
from federatedml.feature.sparse_vector import SparseVector
from federatedml.util import consts

bin_num = 10  # ten bins -> bin_num - 1 expected split points per feature
TEST_LARGE_DATA = False  # set to True to also run the slow large-data test


class TestQuantileBinning(unittest.TestCase):
    def setUp(self):
        self.job_id = str(uuid.uuid1())
        session.init(self.job_id)
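
    # Dense case: gen_data cycles the values 0-9 over the rows (periodically
    # remapping a 0 to 9), so ten-bin quantile binning should place its split
    # points exactly on [1.0, 2.0, ..., 9.0] for every feature.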
    def test_binning_correctness(self):
        bin_obj = self._bin_obj_generator()
        small_table = self.gen_data(10000, 50, 2)
        split_points = bin_obj.fit_split_points(small_table)
        expect_split_points = [float(x) for x in range(1, bin_num)]
        for _, s_ps in split_points.items():
            self.assertListEqual(s_ps.tolist(), expect_split_points)

    def test_large_binning(self):
        # Slow smoke test: 100k rows x 1000 random features on 48 partitions.
        if TEST_LARGE_DATA:
            bin_obj = self._bin_obj_generator()
            big_table = self.gen_data(100000, 1000, 48, use_random=True)
            _ = bin_obj.fit_split_points(big_table)
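
    # Sparse case: each SparseVector has shape 10 * feature_num, so columns
    # with index >= feature_num exist only implicitly (all zeros); their
    # split points are skipped below.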
    def test_sparse_data(self):
        feature_num = 50
        bin_obj = self._bin_obj_generator()
        small_table = self.gen_data(10000, feature_num, 2, is_sparse=True)
        split_points = bin_obj.fit_split_points(small_table)
        expect_split_points = [float(x) for x in range(1, bin_num)]
        for feature_name, s_ps in split_points.items():
            if int(feature_name) >= feature_num:
                continue
            self.assertListEqual(s_ps.tolist(), expect_split_points)
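
    # Values listed in abnormal_list are excluded from binning, so with
    # bin_num reduced by len(abnormal_list) the expected split points are
    # 1-9 with 3 and 4 removed.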
    def test_abnormal(self):
        abnormal_list = [3, 4]
        bin_obj = self._bin_obj_generator(abnormal_list=abnormal_list,
                                          this_bin_num=bin_num - len(abnormal_list))
        small_table = self.gen_data(10000, 50, 2)
        split_points = bin_obj.fit_split_points(small_table)
        expect_split_points = [float(x) for x in range(1, bin_num) if x not in abnormal_list]
        for _, s_ps in split_points.items():
            self.assertListEqual(s_ps.tolist(), expect_split_points)

    def _bin_obj_generator(self, abnormal_list: list = None, this_bin_num=bin_num):
        bin_param = FeatureBinningParam(method='quantile',
                                        compress_thres=consts.DEFAULT_COMPRESS_THRESHOLD,
                                        head_size=consts.DEFAULT_HEAD_SIZE,
                                        error=consts.DEFAULT_RELATIVE_ERROR,
                                        bin_indexes=-1,
                                        bin_num=this_bin_num)
        return QuantileBinning(bin_param, abnormal_list=abnormal_list)
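
    # Builds a distributed table of Instance rows. Every feature of a row
    # carries the same value (data_key % bin_num, with one zero per cycle
    # shifted up to bin_num - 1) unless use_random is set.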
    def gen_data(self, data_num, feature_num, partition, is_sparse=False, use_random=False):
        data = []
        shift_iter = 0
        header = [str(i) for i in range(feature_num)]
        anonymous_header = ["guest_9999_x" + str(i) for i in range(feature_num)]
        for data_key in range(data_num):
            value = data_key % bin_num
            if value == 0:
                if shift_iter % bin_num == 0:
                    value = bin_num - 1
                shift_iter += 1
            if use_random:
                features = np.random.random(feature_num)
            else:
                features = value * np.ones(feature_num)
            if not is_sparse:
                inst = Instance(inst_id=data_key, features=features, label=data_key % 2)
            else:
                data_index = list(range(feature_num))
                sparse_inst = SparseVector(data_index, data=features, shape=10 * feature_num)
                inst = Instance(inst_id=data_key, features=sparse_inst, label=data_key % 2)
                header = [str(i) for i in range(feature_num * 10)]
            data.append((data_key, inst))
        result = session.parallelize(data, include_key=True, partition=partition)
        result.schema = {'header': header,
                         'anonymous_header': anonymous_header}
        return result

    def tearDown(self):
        session.stop()


if __name__ == '__main__':
    unittest.main()
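
# Run note (assumes a FATE standalone deployment so that fate_arch and
# federatedml are importable):
#     python -m unittest quantile_binning_test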