test_psi.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import unittest
  2. import numpy as np
  3. import time
  4. import copy
  5. import uuid
  6. from fate_arch.session import computing_session as session
  7. from federatedml.feature.instance import Instance
  8. from federatedml.feature.sparse_vector import SparseVector
  9. from federatedml.statistic.psi.psi import PSI
  10. from federatedml.param.psi_param import PSIParam
  class TestPSI(unittest.TestCase):
      """Smoke tests for PSI.fit on randomly generated dense and sparse tables.

      setUp builds two table pairs sharing one integer header: a dense pair whose
      features are numpy arrays, and a sparse pair whose features are SparseVectors.
      Each test fits a PSI model (default PSIParam) on one pair. There are no
      assertions; a test passes as long as fit does not raise.
      """

      # Shape of the generated data.
      N_FEATURES = 20
      N_ROWS_FIRST = 100
      N_ROWS_SECOND = 1000
      PARTITIONS = 4

      def setUp(self):
          """Start a computing session and build the dense and sparse table pairs."""
          session.init('test', 0)
          header = list(range(self.N_FEATURES))

          print('generating dense tables')
          self.dense_table1 = self._make_table(self.N_ROWS_FIRST, header, sparse=False)
          self.dense_table2 = self._make_table(self.N_ROWS_SECOND, header, sparse=False)
          print('generating done')

          print('generating sparse tables')
          self.sp_table1 = self._make_table(self.N_ROWS_FIRST, header, sparse=True)
          self.sp_table2 = self._make_table(self.N_ROWS_SECOND, header, sparse=True)
          print('generating done')

      def tearDown(self):
          """Stop the session opened in setUp.

          setUp runs before every test and always inits the same session id, so
          without this the next test would re-init while the previous session is
          still live.
          """
          session.stop()

      def _make_table(self, n_rows, header, sparse):
          """Return a table of ``n_rows`` Instances with uniform random features.

          Args:
              n_rows: number of Instances to generate.
              header: column names; its length is the number of features per row,
                  and a copy of it is stored as the table's schema header.
              sparse: if True, each row's features are a SparseVector whose indices
                  are a copy of ``header``; otherwise a numpy array.

          Returns:
              The table produced by ``session.parallelize`` (without keys, split
              into ``PARTITIONS`` partitions).
          """
          n_features = len(header)
          rows = []
          for _ in range(n_rows):
              inst = Instance()
              if sparse:
                  inst.features = SparseVector(indices=copy.deepcopy(header),
                                               data=list(np.random.random(n_features)))
              else:
                  inst.features = np.random.random(n_features)
              rows.append(inst)
          table = session.parallelize(rows, partition=self.PARTITIONS, include_key=False)
          # Copy so each table owns its header list independently of the caller's.
          table.schema['header'] = copy.deepcopy(header)
          return table

      @staticmethod
      def _fit_psi(table1, table2):
          """Fit a PSI model with default PSIParam on the two tables and return it."""
          param = PSIParam()
          psi = PSI()
          psi._init_model(param)
          psi.fit(table1, table2)
          return psi

      def test_dense_psi(self):
          """PSI.fit runs on the dense (numpy-array features) table pair."""
          self._fit_psi(self.dense_table1, self.dense_table2)
          print('dense testing done')

      def test_sparse_psi(self):
          """PSI.fit runs on the sparse (SparseVector features) table pair."""
          self._fit_psi(self.sp_table1, self.sp_table2)
          print('sparse testing done')
  if __name__ == "__main__":
      # Load every TestCase defined in this module and run it, exiting with the result.
      unittest.main(module="__main__")