base_binning_test.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  17. import unittest
  18. import numpy as np
  19. from fate_arch.session import computing_session as session
# Initialise the computing session once, at import time, before the
# federatedml imports that follow.
# NOTE(review): confirm whether those imports actually require an active
# session; "123" is presumably the session id for this test run.
session.init("123")
  21. from federatedml.feature.instance import Instance
  22. from federatedml.statistic.statics import MultivariateStatisticalSummary
  23. class TestBaseBinningFunctions(unittest.TestCase):
  24. def setUp(self):
  25. self.table_list = []
  26. def _gen_data(self, label_histogram: dict, partition=10):
  27. label_list = []
  28. data_num = 0
  29. for y, n in label_histogram.items():
  30. data_num += n
  31. label_list.extend([y] * n)
  32. np.random.shuffle(label_list)
  33. data_insts = []
  34. for i in range(data_num):
  35. features = np.random.randn(10)
  36. inst = Instance(features=features, label=label_list[i])
  37. data_insts.append((i, inst))
  38. result = session.parallelize(data_insts, include_key=True, partition=partition)
  39. result.schema = {'header': ['d' + str(x) for x in range(10)]}
  40. self.table_list.append(result)
  41. return result
  42. def test_histogram(self):
  43. histograms = [
  44. {0: 100, 1: 100},
  45. {0: 9700, 1: 300},
  46. {0: 2000, 1: 18000},
  47. {0: 8000, 1: 2000}
  48. ]
  49. partitions = [10, 1, 48, 32]
  50. for i, h in enumerate(histograms):
  51. data = self._gen_data(h, partitions[i])
  52. summary_obj = MultivariateStatisticalSummary(data_instances=data)
  53. label_hist = summary_obj.get_label_histogram()
  54. self.assertDictEqual(h, label_hist)
  55. def tearDown(self):
  56. # for table in self.table_list:
  57. # table.destroy()
  58. session.stop()
  59. if __name__ == '__main__':
  60. unittest.main()