sample_weight_test.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import time
  17. import unittest
  18. import uuid
  19. import numpy as np
  20. from fate_arch.session import computing_session as session
  21. from federatedml.feature.instance import Instance
  22. from federatedml.util.sample_weight import SampleWeight
  class TestSampleWeight(unittest.TestCase):
      """Unit tests for ``SampleWeight`` against a small parallelized table.

      ``setUp`` builds a 10-row table keyed by ids 1..10. Ids that are multiples
      of 5 (ids 5 and 10) get label 1; the other 8 ids get label 0. Every row has
      three random features whose header names are ``x0``, ``x1`` and ``x2``.
      """

      def setUp(self):
          session.init("test_sample_weight_" + str(uuid.uuid1()))
          # Class weights keyed by the string form of the label; the tests look
          # them up with ``str(label)``.
          self.class_weight = {"0": 2, "1": 3}
          data = []
          for i in range(1, 11):
              # ids 5 and 10 are positives (label 1); the remaining 8 are label 0.
              label = 1 if i % 5 == 0 else 0
              instance = Instance(inst_id=i, features=np.random.random(3), label=label)
              data.append((i, instance))
          schema = {"header": ["x0", "x1", "x2"],
                    "sid": "id", "label_name": "y"}
          self.table = session.parallelize(data, include_key=True, partition=8)
          self.table.schema = schema
          self.sample_weight_obj = SampleWeight()

      def test_get_class_weight(self):
          """Class weights computed from the label distribution of the fixture."""
          class_weight = self.sample_weight_obj.get_class_weight(self.table)
          # 10 samples, 2 classes: 2 rows of label 1, 8 rows of label 0. The
          # expected values equal n_samples / (n_classes * class_count):
          # 10 / (2 * 2) == 10 / 4 and 10 / (2 * 8) == 10 / 16.
          c_class_weight = {"1": 10 / 4, "0": 10 / 16}
          self.assertDictEqual(class_weight, c_class_weight)

      def test_replace_weight(self):
          """A single instance receives the weight configured for its label."""
          # ``first()`` yields a (key, Instance) pair; index 1 is the Instance.
          first_row = self.table.first()
          weighted_instance = self.sample_weight_obj.replace_weight(first_row[1], self.class_weight)
          self.assertEqual(weighted_instance.weight, self.class_weight[str(weighted_instance.label)])

      def test_assign_sample_weight(self):
          """Every row of the weighted table carries the weight of its label."""
          weighted_table = self.sample_weight_obj.assign_sample_weight(self.table, self.class_weight, None, False)
          # Assert on the driver rather than inside ``mapValues``. An assertion
          # raised in a mapped function runs inside the computing backend, where
          # evaluation may be deferred and the error may not propagate to
          # unittest. The original also discarded the mapped result, so the test
          # could pass without checking anything.
          checked = 0
          for _, inst in weighted_table.collect():
              self.assertEqual(inst.weight, self.class_weight[str(inst.label)])
              checked += 1
          # Guard against a vacuous pass on an empty or truncated result.
          self.assertEqual(checked, self.table.count())

      def test_get_weight_loc(self):
          """The weight column is located by its position in the header."""
          # "x2" is the third entry of the schema header, i.e. index 2.
          c_loc = 2
          loc = self.sample_weight_obj.get_weight_loc(self.table, "x2")
          self.assertEqual(loc, c_loc)

      def tearDown(self):
          session.stop()
  if __name__ == "__main__":
      # Run this module's test cases when it is executed directly as a script.
      unittest.main()