# label_transform_test.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import unittest
  17. import uuid
  18. import numpy as np
  19. from fate_arch.session import computing_session as session
  20. from federatedml.feature.instance import Instance
  21. from federatedml.statistic.data_overview import predict_detail_dict_to_str
  22. from federatedml.util.label_transform import LabelTransformer
  23. class TestLabelTransform(unittest.TestCase):
  24. def setUp(self):
  25. session.init("test_label_transform_" + str(uuid.uuid1()))
  26. self.label_encoder = {"yes": 1, "no": 0}
  27. self.predict_label_encoder = {1: "yes", 0: "no"}
  28. data = []
  29. for i in range(1, 11):
  30. label = "yes" if i % 5 == 0 else "no"
  31. instance = Instance(inst_id=i, features=np.random.random(3), label=label)
  32. data.append((i, instance))
  33. schema = {"header": ["x0", "x1", "x2"],
  34. "sid": "id", "label_name": "y"}
  35. self.table = session.parallelize(data, include_key=True, partition=8)
  36. self.table.schema = schema
  37. self.label_transformer_obj = LabelTransformer()
  38. def test_get_label_encoder(self):
  39. self.label_transformer_obj.update_label_encoder(self.table)
  40. c_label_encoder = {"yes": 1, "no": 0}
  41. self.assertDictEqual(self.label_transformer_obj.label_encoder, c_label_encoder)
  42. def test_replace_instance_label(self):
  43. instance = self.table.first()[1]
  44. replaced_instance = self.label_transformer_obj.replace_instance_label(instance, self.label_encoder)
  45. self.assertEqual(replaced_instance.label, self.label_encoder[instance.label])
  46. def test_transform_data_label(self):
  47. replaced_data = self.label_transformer_obj.transform_data_label(self.table, self.label_encoder)
  48. replaced_data.join(self.table, lambda x, y: self.assertEqual(x.label, self.label_encoder[y.label]))
  49. def test_replace_predict_label(self):
  50. true_label, predict_label, predict_score, predict_detail, predict_type = 1, 0, 0.1, {
  51. "1": 0.1, "0": 0.9}, "train"
  52. predict_detail = predict_detail_dict_to_str(predict_detail)
  53. predict_result = Instance(inst_id=0,
  54. features=[true_label, predict_label, predict_score, predict_detail, predict_type])
  55. r_predict_instance = self.label_transformer_obj.replace_predict_label(
  56. predict_result, self.predict_label_encoder)
  57. r_predict_result = r_predict_instance.features
  58. c_predict_detail = predict_detail_dict_to_str({"yes": 0.1, "no": 0.9})
  59. c_predict_result = ["yes", "no", predict_score, c_predict_detail, predict_type]
  60. self.assertEqual(r_predict_result, c_predict_result)
  61. if __name__ == '__main__':
  62. unittest.main()