# one_hot_test.py (2.5 KB)
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import unittest
  18. from fate_arch.session import computing_session as session
  19. session.init("123")
  20. from federatedml.feature.one_hot_encoder import OneHotEncoder
  21. from federatedml.feature.instance import Instance
  22. from federatedml.util.anonymous_generator_util import Anonymous
  23. import numpy as np
  24. class TestOneHotEncoder(unittest.TestCase):
  25. def setUp(self):
  26. self.data_num = 1000
  27. self.feature_num = 3
  28. self.cols = [0, 1, 2, 3]
  29. self.header = ['x' + str(i) for i in range(self.feature_num)]
  30. self.anonymous_header = ["guest_9999_x" + str(i) for i in range(self.feature_num)]
  31. final_result = []
  32. for i in range(self.data_num):
  33. tmp = []
  34. for _ in range(self.feature_num):
  35. tmp.append(np.random.choice([1, 2, 3, 'test_str']))
  36. tmp = np.array(tmp)
  37. inst = Instance(inst_id=i, features=tmp, label=0)
  38. tmp_pair = (str(i), inst)
  39. final_result.append(tmp_pair)
  40. table = session.parallelize(final_result,
  41. include_key=True,
  42. partition=10)
  43. table.schema = {"header": self.header,
  44. "anonymous_header": self.anonymous_header}
  45. self.model_name = 'OneHotEncoder'
  46. self.table = table
  47. self.args = {"data": {self.model_name: {"data": table}}}
  48. def test_instance(self):
  49. one_hot_encoder = OneHotEncoder()
  50. one_hot_encoder.anonymous_generator = Anonymous()
  51. one_hot_encoder.cols = self.cols
  52. one_hot_encoder.cols_index = self.cols
  53. result = one_hot_encoder.fit(self.table)
  54. local_result = result.collect()
  55. for k, v in local_result:
  56. new_features = v.features
  57. self.assertTrue(len(new_features) == self.feature_num * len(self.cols))
  58. if __name__ == '__main__':
  59. unittest.main()