label_transform_param.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from pipeline.param.base_param import BaseParam
  19. class LabelTransformParam(BaseParam):
  20. """
  21. Define label transform param that used in label transform.
  22. Parameters
  23. ----------
  24. label_encoder : None or dict, default : None
  25. Specify (label, encoded label) key-value pairs for transforming labels to new values.
  26. e.g. {"Yes": 1, "No": 0};
  27. **new in ver 1.9: during training, input labels not found in `label_encoder` will retain its original value
  28. label_list : None or list, default : None
  29. List all input labels, used for matching types of original keys in label_encoder dict,
  30. length should match key count in label_encoder, e.g. ["Yes", "No"];
  31. **new in ver 1.9: given non-emtpy `label_encoder`, when `label_list` not provided,
  32. module will inference label types from input data
  33. need_run: bool, default: True
  34. Specify whether to run label transform
  35. """
  36. def __init__(self, label_encoder=None, label_list=None, need_run=True):
  37. super(LabelTransformParam, self).__init__()
  38. self.label_encoder = label_encoder
  39. self.label_list = label_list
  40. self.need_run = need_run
  41. def check(self):
  42. model_param_descr = "label transform param's "
  43. BaseParam.check_boolean(self.need_run, f"{model_param_descr} need run ")
  44. if self.label_encoder is not None:
  45. if not isinstance(self.label_encoder, dict):
  46. raise ValueError(f"{model_param_descr} label_encoder should be dict type")
  47. if len(self.label_encoder) == 0:
  48. self.label_encoder = None
  49. if self.label_list is not None:
  50. if not isinstance(self.label_list, list):
  51. raise ValueError(f"{model_param_descr} label_list should be list type")
  52. if self.label_encoder and self.label_list and len(self.label_list) != len(self.label_encoder.keys()):
  53. raise ValueError(f"label_list's length not matching label_encoder key count.")
  54. if len(self.label_list) == 0:
  55. self.label_list = None
  56. return True