statistic_cpn_test.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import unittest
  2. import uuid
  3. import numpy as np
  4. from fate_arch.session import computing_session as session
  5. from federatedml.param.statistics_param import StatisticsParam
  6. from federatedml.statistic.data_statistics import DataStatistics
  7. from federatedml.feature.instance import Instance
  8. class TestStatisticCpn(unittest.TestCase):
  9. def setUp(self):
  10. self.job_id = str(uuid.uuid1())
  11. session.init(self.job_id)
  12. def gen_data(self, data_num, partition):
  13. data = []
  14. header = [str(i) for i in range(2)]
  15. anonymous_header = ["guest_9999_x" + str(i) for i in range(2)]
  16. col_1 = np.random.randn(data_num)
  17. col_2 = np.random.rand(data_num)
  18. for key in range(data_num):
  19. data.append((key, Instance(features=np.array([col_1[key], col_2[key]]))))
  20. result = session.parallelize(data, include_key=True, partition=partition)
  21. result.schema = {'header': header,
  22. "anonymous_header": anonymous_header}
  23. self.header = header
  24. self.col_1 = col_1
  25. self.col_2 = col_2
  26. return result
  27. def test_something(self):
  28. statistics_param = StatisticsParam(statistics="summary")
  29. statistics_param.check()
  30. print(statistics_param.statistics)
  31. test_data = self.gen_data(1000, 16)
  32. test_obj = DataStatistics()
  33. test_obj.model_param = statistics_param
  34. test_obj._init_model(statistics_param)
  35. test_obj.fit(test_data)
  36. static_result = test_obj.summary()
  37. stat_res_1 = static_result[self.header[0]]
  38. self.assertTrue(self._float_equal(stat_res_1['sum'], np.sum(self.col_1)))
  39. self.assertTrue(self._float_equal(stat_res_1['max'], np.max(self.col_1)))
  40. self.assertTrue(self._float_equal(stat_res_1['mean'], np.mean(self.col_1)))
  41. self.assertTrue(self._float_equal(stat_res_1['stddev'], np.std(self.col_1)))
  42. self.assertTrue(self._float_equal(stat_res_1['min'], np.min(self.col_1)))
  43. # self.assertEqual(True, False)
  44. def _float_equal(self, x, y, error=1e-6):
  45. if np.fabs(x - y) < error:
  46. return True
  47. print(f"x: {x}, y: {y}")
  48. return False
  49. def tearDown(self):
  50. session.stop()
  51. if __name__ == '__main__':
  52. unittest.main()