label_transform_param.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from federatedml.param.base_param import BaseParam
  19. from federatedml.util import LOGGER
  20. class LabelTransformParam(BaseParam):
  21. """
  22. Define label transform param that used in label transform.
  23. Parameters
  24. ----------
  25. label_encoder : None or dict, default : None
  26. Specify (label, encoded label) key-value pairs for transforming labels to new values.
  27. e.g. {"Yes": 1, "No": 0};
  28. **new in ver 1.9: during training, input labels not found in `label_encoder` will retain its original value
  29. label_list : None or list, default : None
  30. List all input labels, used for matching types of original keys in label_encoder dict,
  31. length should match key count in label_encoder, e.g. ["Yes", "No"];
  32. **new in ver 1.9: given non-emtpy `label_encoder`, when `label_list` not provided,
  33. module will inference label types from input data
  34. need_run: bool, default: True
  35. Specify whether to run label transform
  36. """
  37. def __init__(self, label_encoder=None, label_list=None, need_run=True):
  38. super(LabelTransformParam, self).__init__()
  39. self.label_encoder = label_encoder
  40. self.label_list = label_list
  41. self.need_run = need_run
  42. def check(self):
  43. model_param_descr = "label transform param's "
  44. BaseParam.check_boolean(self.need_run, f"{model_param_descr} need run ")
  45. if self.label_encoder is not None:
  46. if not isinstance(self.label_encoder, dict):
  47. raise ValueError(f"{model_param_descr} label_encoder should be dict type")
  48. if len(self.label_encoder) == 0:
  49. self.label_encoder = None
  50. if self.label_list is not None:
  51. if not isinstance(self.label_list, list):
  52. raise ValueError(f"{model_param_descr} label_list should be list type")
  53. if self.label_encoder and self.label_list and len(self.label_list) != len(self.label_encoder.keys()):
  54. raise ValueError(f"label_list's length not matching label_encoder key count.")
  55. if len(self.label_list) == 0:
  56. self.label_list = None
  57. LOGGER.debug("Finish label transformer parameter check!")
  58. return True