sparse_vector_test.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import unittest
  17. from federatedml.feature.sparse_vector import SparseVector
  18. class TestSparseVector(unittest.TestCase):
  19. def setUp(self):
  20. pass
  21. def test_instance(self):
  22. indices = []
  23. data = []
  24. for i in range(1, 10):
  25. indices.append(i * i)
  26. data.append(i ** 3)
  27. shape = 100
  28. sparse_data = SparseVector(indices, data, shape)
  29. self.assertTrue(sparse_data.shape == shape and len(sparse_data.sparse_vec) == 9)
  30. self.assertTrue(sparse_data.count_zeros() == 91)
  31. self.assertTrue(sparse_data.count_non_zeros() == 9)
  32. for idx, val in zip(indices, data):
  33. self.assertTrue(sparse_data.get_data(idx) == val)
  34. for i in range(100):
  35. if i in indices:
  36. continue
  37. self.assertTrue(sparse_data.get_data(i, i ** 4) == i ** 4)
  38. self.assertTrue(dict(sparse_data.get_all_data()) == dict(zip(indices, data)))
  39. if __name__ == '__main__':
  40. unittest.main()