# quantile_test.py
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import unittest
  17. import numpy as np
  18. import random
  19. from fate_arch.session import computing_session as session
  20. import uuid
  21. from federatedml.feature.binning.quantile_binning import QuantileBinning
  22. from federatedml.feature.instance import Instance
  23. # from federatedml.feature.quantile import Quantile
  24. from federatedml.feature.sparse_vector import SparseVector
  25. from federatedml.param.feature_binning_param import FeatureBinningParam
  26. class TestInstance(unittest.TestCase):
  27. def setUp(self):
  28. self.job_id = str(uuid.uuid1())
  29. session.init(self.job_id)
  30. # session.init("test_instance")
  31. def gen_data(self):
  32. dense_inst = []
  33. headers = ['x' + str(i) for i in range(20)]
  34. anonymous_header = ["guest_9999_x" + str(i) for i in range(20)]
  35. for i in range(100):
  36. inst = Instance(features=(i % 16 * np.ones(20)))
  37. dense_inst.append((i, inst))
  38. self.dense_table = session.parallelize(dense_inst, include_key=True, partition=2)
  39. self.dense_table.schema = {'header': headers,
  40. "anonymous_header": anonymous_header}
  41. self.sparse_inst = []
  42. for i in range(100):
  43. dict = {}
  44. indices = []
  45. data = []
  46. for j in range(20):
  47. idx = random.randint(0, 29)
  48. if idx in dict:
  49. continue
  50. dict[idx] = 1
  51. val = random.random()
  52. indices.append(idx)
  53. data.append(val)
  54. sparse_vec = SparseVector(indices, data, 30)
  55. self.sparse_inst.append((i, Instance(features=sparse_vec)))
  56. self.sparse_table = session.parallelize(self.sparse_inst, include_key=True, partition=48)
  57. self.sparse_table.schema = {"header": ["fid" + str(i) for i in range(30)]}
  58. # self.sparse_table = eggroll.parallelize(sparse_inst, include_key=True, partition=1)
  59. """
  60. def test_dense_quantile(self):
  61. data_bin, bin_splitpoints, bin_sparse = Quantile.convert_feature_to_bin(self.dense_table, "bin_by_sample_data",
  62. bin_num=4)
  63. bin_result = dict([(key, inst.features) for key, inst in data_bin.collect()])
  64. for i in range(100):
  65. self.assertTrue((bin_result[i] == np.ones(20, dtype='int') * ((i % 16) // 4)).all())
  66. if i < 20:
  67. self.assertTrue((bin_splitpoints[i] == np.asarray([3, 7, 11, 15], dtype='int')).all())
  68. data_bin, bin_splitpoints, bin_sparse = Quantile.convert_feature_to_bin(self.dense_table, "bin_by_data_block",
  69. bin_num=4)
  70. for i in range(20):
  71. self.assertTrue(bin_splitpoints[i].shape[0] <= 4)
  72. def test_sparse_quantile(self):
  73. data_bin, bin_splitpoints, bin_sparse = Quantile.convert_feature_to_bin(self.sparse_table, "bin_by_sample_data",
  74. bin_num=4)
  75. bin_result = dict([(key, inst.features) for key, inst in data_bin.collect()])
  76. for i in range(20):
  77. self.assertTrue(len(self.sparse_inst[i][1].features.sparse_vec) == len(bin_result[i].sparse_vec))
  78. """
  79. """
  80. def test_new_sparse_quantile(self):
  81. self.gen_data()
  82. param_obj = FeatureBinningParam(bin_num=4)
  83. binning_obj = QuantileBinning(param_obj)
  84. binning_obj.fit_split_points(self.sparse_table)
  85. data_bin, bin_splitpoints, bin_sparse = binning_obj.convert_feature_to_bin(self.sparse_table)
  86. bin_result = dict([(key, inst.features) for key, inst in data_bin.collect()])
  87. for i in range(20):
  88. self.assertTrue(len(self.sparse_inst[i][1].features.sparse_vec) == len(bin_result[i].sparse_vec))
  89. """
  90. def test_new_dense_quantile(self):
  91. self.gen_data()
  92. param_obj = FeatureBinningParam(bin_num=4)
  93. binning_obj = QuantileBinning(param_obj)
  94. binning_obj.fit_split_points(self.dense_table)
  95. data_bin, bin_splitpoints, bin_sparse = binning_obj.convert_feature_to_bin(self.dense_table)
  96. bin_result = dict([(key, inst.features) for key, inst in data_bin.collect()])
  97. # print(bin_result)
  98. for i in range(100):
  99. self.assertTrue((bin_result[i] == np.ones(20, dtype='int') * ((i % 16) // 4)).all())
  100. if i < 20:
  101. # col_name = 'x' + str(i)
  102. col_idx = i
  103. split_point = np.array(bin_splitpoints[col_idx])
  104. self.assertTrue((split_point == np.asarray([3, 7, 11, 15], dtype='int')).all())
  105. for split_points in bin_splitpoints:
  106. self.assertTrue(len(split_points) <= 4)
  107. def tearDown(self):
  108. session.stop()
  109. # try:
  110. # session.cleanup("*", self.job_id, True)
  111. # except EnvironmentError:
  112. # pass
  113. # try:
  114. # session.cleanup("*", self.job_id, False)
  115. # except EnvironmentError:
  116. # pass
# Allow running this test module directly: `python quantile_test.py`.
if __name__ == '__main__':
    unittest.main()