# iv_calculator_test.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
  17. import unittest
  18. import uuid
  19. import numpy as np
  20. from fate_arch.session import computing_session as session
  21. from fate_arch.session import Session
  22. from federatedml.feature.binning.quantile_binning import QuantileBinning
  23. from federatedml.feature.binning.iv_calculator import IvCalculator
  24. from federatedml.param.feature_binning_param import FeatureBinningParam
  25. from federatedml.feature.instance import Instance
  26. from federatedml.feature.sparse_vector import SparseVector
  27. from federatedml.util import consts
  28. bin_num = 10
  29. class TestIvCalculator(unittest.TestCase):
  30. def setUp(self):
  31. self.job_id = str(uuid.uuid1())
  32. # session = Session.create(0, 0).init_computing("abc").computing
  33. session.init(self.job_id)
  34. def test_iv_calculator(self):
  35. bin_obj = self._bin_obj_generator()
  36. small_table = self.gen_data(10000, 50, 2)
  37. split_points = bin_obj.fit_split_points(small_table)
  38. iv_calculator = IvCalculator(adjustment_factor=0.5, role="guest", party_id=9999)
  39. ivs = iv_calculator.cal_local_iv(small_table, split_points)
  40. print(f"iv result: {ivs.summary()}")
  41. # def test_sparse_data(self):
  42. # feature_num = 50
  43. # bin_obj = self._bin_obj_generator()
  44. # small_table = self.gen_data(10000, feature_num, 2, is_sparse=True)
  45. # split_points = bin_obj.fit_split_points(small_table)
  46. # expect_split_points = list((range(1, bin_num)))
  47. # expect_split_points = [float(x) for x in expect_split_points]
  48. #
  49. # for feature_name, s_ps in split_points.items():
  50. # if int(feature_name) >= feature_num:
  51. # continue
  52. # s_ps = s_ps.tolist()
  53. # self.assertListEqual(s_ps, expect_split_points)
  54. #
  55. # def test_abnormal(self):
  56. # abnormal_list = [3, 4]
  57. # bin_obj = self._bin_obj_generator(abnormal_list=abnormal_list, this_bin_num=bin_num - len(abnormal_list))
  58. # small_table = self.gen_data(10000, 50, 2)
  59. # split_points = bin_obj.fit_split_points(small_table)
  60. # expect_split_points = list((range(1, bin_num)))
  61. # expect_split_points = [float(x) for x in expect_split_points if x not in abnormal_list]
  62. #
  63. # for _, s_ps in split_points.items():
  64. # s_ps = s_ps.tolist()
  65. # self.assertListEqual(s_ps, expect_split_points)
  66. #
  67. def _bin_obj_generator(self, abnormal_list: list = None, this_bin_num=bin_num):
  68. bin_param = FeatureBinningParam(method='quantile', compress_thres=consts.DEFAULT_COMPRESS_THRESHOLD,
  69. head_size=consts.DEFAULT_HEAD_SIZE,
  70. error=consts.DEFAULT_RELATIVE_ERROR,
  71. bin_indexes=-1,
  72. bin_num=this_bin_num)
  73. bin_obj = QuantileBinning(bin_param, abnormal_list=abnormal_list)
  74. return bin_obj
  75. def gen_data(self, data_num, feature_num, partition, is_sparse=False, use_random=False):
  76. data = []
  77. shift_iter = 0
  78. header = [str(i) for i in range(feature_num)]
  79. anonymous_header = ["guest_9999_x" + str(i) for i in range(feature_num)]
  80. for data_key in range(data_num):
  81. value = data_key % bin_num
  82. if value == 0:
  83. if shift_iter % bin_num == 0:
  84. value = bin_num - 1
  85. shift_iter += 1
  86. if not is_sparse:
  87. if not use_random:
  88. features = value * np.ones(feature_num)
  89. else:
  90. features = np.random.random(feature_num)
  91. inst = Instance(inst_id=data_key, features=features, label=data_key % 2)
  92. else:
  93. if not use_random:
  94. features = value * np.ones(feature_num)
  95. else:
  96. features = np.random.random(feature_num)
  97. data_index = [x for x in range(feature_num)]
  98. sparse_inst = SparseVector(data_index, data=features, shape=10 * feature_num)
  99. inst = Instance(inst_id=data_key, features=sparse_inst, label=data_key % 2)
  100. header = [str(i) for i in range(feature_num * 10)]
  101. data.append((data_key, inst))
  102. result = session.parallelize(data, include_key=True, partition=partition)
  103. result.schema = {'header': header,
  104. "anonymous_header": anonymous_header}
  105. return result
  106. def tearDown(self):
  107. session.stop()
  108. if __name__ == '__main__':
  109. unittest.main()