statics_test.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
import math
import time
import unittest
import uuid

import numpy as np

from fate_arch.session import computing_session as session
# NOTE(review): `consts` is not referenced anywhere in the code shown in this file.
from federatedml.util import consts

# A computing session is initialised at import time, before the federatedml
# imports below are executed. NOTE(review): presumably those imports (or code
# they run at import) require an active session; confirm before reordering or
# removing this call. TestStatistics.setUp() additionally calls session.init
# with a fresh job id for every test.
session.init("123")

from federatedml.feature.instance import Instance
from federatedml.statistic.statics import MultivariateStatisticalSummary
  11. class TestStatistics(unittest.TestCase):
  12. def setUp(self):
  13. self.job_id = str(uuid.uuid1())
  14. session.init(self.job_id)
  15. self.eps = 1e-5
  16. self.count = 1000
  17. self.feature_num = 100
  18. self._dense_table, self._dense_not_inst_table, self._original_data = None, None, None
  19. def _gen_table_data(self):
  20. if self._dense_table is not None:
  21. return self._dense_table, self._dense_not_inst_table, self._original_data
  22. headers = ['x' + str(i) for i in range(self.feature_num)]
  23. dense_inst = []
  24. dense_not_inst = []
  25. original_data = 100 * np.random.random((self.count, self.feature_num))
  26. # original_data = 100 * np.zeros((self.count, self.feature_num))
  27. for i in range(self.count):
  28. features = original_data[i, :]
  29. inst = Instance(features=features)
  30. dense_inst.append((i, inst))
  31. dense_not_inst.append((i, features))
  32. dense_table = session.parallelize(dense_inst, include_key=True, partition=16)
  33. dense_not_inst_table = session.parallelize(dense_not_inst, include_key=True, partition=16)
  34. dense_table.schema = {'header': headers}
  35. dense_not_inst_table.schema = {'header': headers}
  36. self._dense_table, self._dense_not_inst_table, self._original_data = \
  37. dense_table, dense_not_inst_table, original_data
  38. return dense_table, dense_not_inst_table, original_data
  39. def _gen_missing_table(self):
  40. headers = ['x' + str(i) for i in range(self.feature_num)]
  41. dense_inst = []
  42. dense_not_inst = []
  43. original_data = 100 * np.random.random((self.count, self.feature_num))
  44. for i in range(self.count):
  45. features = original_data[i, :]
  46. if i % 2 == 0:
  47. features = np.array([np.nan] * self.feature_num)
  48. inst = Instance(features=features)
  49. dense_inst.append((i, inst))
  50. dense_not_inst.append((i, features))
  51. dense_table = session.parallelize(dense_inst, include_key=True, partition=16)
  52. dense_not_inst_table = session.parallelize(dense_not_inst, include_key=True, partition=16)
  53. dense_table.schema = {'header': headers}
  54. dense_not_inst_table.schema = {'header': headers}
  55. return dense_table, dense_not_inst_table, original_data
  56. def test_MultivariateStatisticalSummary(self):
  57. dense_table, dense_not_inst_table, original_data = self._gen_table_data()
  58. summary_obj = MultivariateStatisticalSummary(dense_table)
  59. self._test_min_max(summary_obj, original_data, dense_table)
  60. self._test_min_max(summary_obj, original_data, dense_not_inst_table)
  61. def _test_min_max(self, summary_obj, original_data, data_table):
  62. # test max, min
  63. max_array = np.max(original_data, axis=0)
  64. min_array = np.min(original_data, axis=0)
  65. mean_array = np.mean(original_data, axis=0)
  66. var_array = np.var(original_data, axis=0)
  67. std_var_array = np.std(original_data, axis=0)
  68. t0 = time.time()
  69. header = data_table.schema['header']
  70. for idx, col_name in enumerate(header):
  71. self.assertEqual(summary_obj.get_max()[col_name], max_array[idx])
  72. self.assertEqual(summary_obj.get_min()[col_name], min_array[idx])
  73. self.assertTrue(self._float_equal(summary_obj.get_mean()[col_name], mean_array[idx]))
  74. self.assertTrue(self._float_equal(summary_obj.get_variance()[col_name], var_array[idx]))
  75. self.assertTrue(self._float_equal(summary_obj.get_std_variance()[col_name], std_var_array[idx]))
  76. print("max value etc, total time: {}".format(time.time() - t0))
  77. def _float_equal(self, x, y, error=1e-6):
  78. if math.fabs(x - y) < error:
  79. return True
  80. print(f"x: {x}, y: {y}")
  81. return False
  82. # def test_median(self):
  83. # error = 0
  84. # dense_table, dense_not_inst_table, original_data = self._gen_table_data()
  85. #
  86. # sorted_matrix = np.sort(original_data, axis=0)
  87. # median_array = sorted_matrix[self.count // 2, :]
  88. # header = dense_table.schema['header']
  89. # summary_obj = MultivariateStatisticalSummary(dense_table, error=error)
  90. # t0 = time.time()
  91. #
  92. # for idx, col_name in enumerate(header):
  93. # self.assertTrue(self._float_equal(summary_obj.get_median()[col_name],
  94. # median_array[idx]))
  95. # print("median interface, total time: {}".format(time.time() - t0))
  96. #
  97. # summary_obj_2 = MultivariateStatisticalSummary(dense_not_inst_table, error=error)
  98. # t0 = time.time()
  99. # for idx, col_name in enumerate(header):
  100. # self.assertTrue(self._float_equal(summary_obj_2.get_median()[col_name],
  101. # median_array[idx]))
  102. # print("median interface, total time: {}".format(time.time() - t0))
  103. #
  104. # def test_quantile_query(self):
  105. #
  106. # dense_table, dense_not_inst_table, original_data = self._gen_table_data()
  107. #
  108. # quantile_points = [0.25, 0.5, 0.75, 1.0]
  109. # quantile_array = np.quantile(original_data, quantile_points, axis=0)
  110. # summary_obj = MultivariateStatisticalSummary(dense_table, error=0)
  111. # header = dense_table.schema['header']
  112. #
  113. # t0 = time.time()
  114. # for q_idx, q in enumerate(quantile_points):
  115. # for idx, col_name in enumerate(header):
  116. # self.assertTrue(self._float_equal(summary_obj.get_quantile_point(q)[col_name],
  117. # quantile_array[q_idx][idx],
  118. # error=3))
  119. # print("quantile interface, total time: {}".format(time.time() - t0))
  120. #
  121. # def test_missing_value(self):
  122. # dense_table, dense_not_inst_table, original_data = self._gen_missing_table()
  123. # summary_obj = MultivariateStatisticalSummary(dense_table, error=0)
  124. # t0 = time.time()
  125. # missing_result = summary_obj.get_missing_ratio()
  126. # for col_name, missing_ratio in missing_result.items():
  127. # self.assertEqual(missing_ratio, 0.5, msg="missing ratio should be 0.5")
  128. # print("calculate missing ratio, total time: {}".format(time.time() - t0))
  129. def test_moment(self):
  130. dense_table, dense_not_inst_table, original_data = self._gen_table_data()
  131. summary_obj = MultivariateStatisticalSummary(dense_table, error=0, stat_order=4, bias=False)
  132. header = dense_table.schema['header']
  133. from scipy import stats
  134. moment_3 = stats.moment(original_data, 3, axis=0)
  135. moment_4 = stats.moment(original_data, 4, axis=0)
  136. skewness = stats.skew(original_data, axis=0, bias=False)
  137. kurtosis = stats.kurtosis(original_data, axis=0, bias=False)
  138. summary_moment_3 = summary_obj.get_statics("moment_3")
  139. summary_moment_4 = summary_obj.get_statics("moment_4")
  140. static_skewness = summary_obj.get_statics("skewness")
  141. static_kurtosis = summary_obj.get_statics("kurtosis")
  142. # print(f"moment: {summary_moment_4}, moment_2: {moment_4}")
  143. for idx, col_name in enumerate(header):
  144. self.assertTrue(self._float_equal(summary_moment_3[col_name],
  145. moment_3[idx]))
  146. self.assertTrue(self._float_equal(summary_moment_4[col_name],
  147. moment_4[idx]))
  148. self.assertTrue(self._float_equal(static_skewness[col_name],
  149. skewness[idx]))
  150. self.assertTrue(self._float_equal(static_kurtosis[col_name],
  151. kurtosis[idx]))
  152. def tearDown(self):
  153. session.stop()
  154. if __name__ == '__main__':
  155. unittest.main()