node_test.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import unittest
  17. from federatedml.ensemble import Node, SplitInfo
  18. class TestNode(unittest.TestCase):
  19. def setUp(self):
  20. pass
  21. def test_node(self):
  22. param_dict = {"id": 5, "sitename": "test", "fid": 55, "bid": 555,
  23. "weight": -1, "is_leaf": True, "sum_grad": 2, "sum_hess": 3,
  24. "left_nodeid": 6, "right_nodeid": 7}
  25. node = Node(id=5, sitename="test", fid=55, bid=555, weight=-1, is_leaf=True,
  26. sum_grad=2, sum_hess=3, left_nodeid=6, right_nodeid=7)
  27. for key in param_dict:
  28. self.assertTrue(param_dict[key] == getattr(node, key))
  29. class TestSplitInfo(unittest.TestCase):
  30. def setUp(self):
  31. pass
  32. def test_splitinfo(self):
  33. pass
  34. param_dict = {"sitename": "testsplitinfo",
  35. "best_fid": 23, "best_bid": 233,
  36. "sum_grad": 2333, "sum_hess": 23333, "gain": 233333}
  37. splitinfo = SplitInfo(sitename="testsplitinfo", best_fid=23, best_bid=233,
  38. sum_grad=2333, sum_hess=23333, gain=233333)
  39. for key in param_dict:
  40. self.assertTrue(param_dict[key] == getattr(splitinfo, key))
  41. if __name__ == '__main__':
  42. unittest.main()