1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import time
- import unittest
- import uuid
- import numpy as np
- from fate_arch.session import computing_session as session
- from federatedml.feature.instance import Instance
- from federatedml.util.sample_weight import SampleWeight
class TestSampleWeight(unittest.TestCase):
    """Unit tests for the ``SampleWeight`` helpers on a small in-session table.

    The fixture table holds 10 ``Instance`` objects keyed 1..10. Every 5th
    instance (ids 5 and 10) has label 1 and the rest have label 0, so the
    class counts are 8 (label 0) and 2 (label 1).
    """

    def setUp(self):
        # Use a uniquely named session per test so runs do not collide.
        session.init("test_sample_weight_" + str(uuid.uuid1()))
        # Explicit class weights for the replace/assign tests. Keys are
        # strings because the assertions look them up with str(label).
        self.class_weight = {"0": 2, "1": 3}
        data = []
        for i in range(1, 11):
            label = 1 if i % 5 == 0 else 0
            instance = Instance(inst_id=i, features=np.random.random(3), label=label)
            data.append((i, instance))
        schema = {"header": ["x0", "x1", "x2"],
                  "sid": "id", "label_name": "y"}
        self.table = session.parallelize(data, include_key=True, partition=8)
        self.table.schema = schema
        self.sample_weight_obj = SampleWeight()

    def test_get_class_weight(self):
        """``get_class_weight`` returns the expected per-label weights."""
        class_weight = self.sample_weight_obj.get_class_weight(self.table)
        # These values equal n_samples / (n_classes * class_count) for
        # 10 samples, 2 classes and counts {"1": 2, "0": 8}.
        c_class_weight = {"1": 10 / 4, "0": 10 / 16}
        self.assertDictEqual(class_weight, c_class_weight)

    def test_replace_weight(self):
        """``replace_weight`` sets an instance's weight from its label's class weight."""
        # first() yields a (key, instance) pair; the instance is at index 1.
        instance = self.table.first()
        weighted_instance = self.sample_weight_obj.replace_weight(instance[1], self.class_weight)
        self.assertEqual(weighted_instance.weight, self.class_weight[str(weighted_instance.label)])

    def test_assign_sample_weight(self):
        """Every instance in the assigned table carries its label's class weight."""
        weighted_table = self.sample_weight_obj.assign_sample_weight(self.table, self.class_weight, None, False)
        # Collect to the driver and assert here. Asserting inside a
        # mapValues lambda runs in the computing backend, and that result
        # table is discarded, so failures may never reach this test.
        weighted = list(weighted_table.collect())
        # Guard against a vacuous pass caused by an empty or truncated table.
        self.assertEqual(len(weighted), self.table.count())
        for _, inst in weighted:
            self.assertEqual(inst.weight, self.class_weight[str(inst.label)])

    def test_get_weight_loc(self):
        """``get_weight_loc`` resolves a header name to its column index."""
        # "x2" is the third entry of the schema header, so its index is 2.
        c_loc = 2
        loc = self.sample_weight_obj.get_weight_loc(self.table, "x2")
        self.assertEqual(loc, c_loc)

    def tearDown(self):
        # Release the session created in setUp.
        session.stop()
if __name__ == "__main__":
    # Run every TestCase in this module when the file is executed directly.
    unittest.main()
|