stepwise_test.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import numpy as np
  17. import unittest
  18. import uuid
  19. from fate_arch.common import profile
  20. from fate_arch.session import computing_session as session
  21. from federatedml.model_selection.stepwise.hetero_stepwise import HeteroStepwise
  22. from federatedml.util import consts
  23. profile._PROFILE_LOG_ENABLED = False
  24. class TestStepwise(unittest.TestCase):
  25. def setUp(self):
  26. self.job_id = str(uuid.uuid1())
  27. session.init("test_random_sampler_" + self.job_id)
  28. model = HeteroStepwise()
  29. model.__setattr__('role', consts.GUEST)
  30. model.__setattr__('fit_intercept', True)
  31. self.model = model
  32. data_num = 100
  33. feature_num = 5
  34. bool_list = [True, False, True, True, False]
  35. self.str_mask = "10110"
  36. self.header = ["x1", "x2", "x3", "x4", "x5"]
  37. self.mask = self.prepare_mask(bool_list)
  38. def prepare_mask(self, bool_list):
  39. mask = np.array(bool_list, dtype=bool)
  40. return mask
  41. def test_get_dfe(self):
  42. real_dfe = 4
  43. dfe = HeteroStepwise.get_dfe(self.model, self.str_mask)
  44. self.assertEqual(dfe, real_dfe)
  45. def test_drop_one(self):
  46. real_masks = [np.array([0, 0, 1, 1, 0], dtype=bool), np.array([1, 0, 0, 1, 0], dtype=bool),
  47. np.array([1, 0, 1, 0, 0], dtype=bool)]
  48. mask_generator = HeteroStepwise.drop_one(self.mask)
  49. i = 0
  50. for mask in mask_generator:
  51. np.testing.assert_array_equal(
  52. mask,
  53. real_masks[i],
  54. f"In stepwise_test drop one: mask{mask} not equal to expected {real_masks[i]}")
  55. i += 1
  56. def test_add_one(self):
  57. real_masks = [np.array([1, 1, 1, 1, 0], dtype=bool), np.array([1, 0, 1, 1, 1], dtype=bool)]
  58. mask_generator = HeteroStepwise.add_one(self.mask)
  59. i = 0
  60. for mask in mask_generator:
  61. np.testing.assert_array_equal(mask, real_masks[i],
  62. f"In stepwise_test add one: mask{mask} not equal to expected {real_masks[i]}")
  63. i += 1
  64. def test_mask2string(self):
  65. real_str_mask = "1011010110"
  66. str_mask = HeteroStepwise.mask2string(self.mask, self.mask)
  67. self.assertTrue(str_mask == real_str_mask)
  68. def test_string2mask(self):
  69. real_mask = np.array([1, 0, 1, 1, 0], dtype=bool)
  70. mask = HeteroStepwise.string2mask(self.str_mask)
  71. np.testing.assert_array_equal(mask, real_mask)
  72. def test_get_to_enter(self):
  73. real_to_enter = ["x2", "x5"]
  74. to_enter = self.model.get_to_enter(self.mask, self.mask, self.header)
  75. self.assertListEqual(to_enter, real_to_enter)
  76. def tearDown(self):
  77. session.stop()
  78. if __name__ == '__main__':
  79. unittest.main()