12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from federatedml.param.base_param import BaseParam
- from federatedml.util import LOGGER
class LabelTransformParam(BaseParam):
    """
    Define the parameters used by the label transform module.

    Parameters
    ----------
    label_encoder : None or dict, default : None
        Specify (label, encoded label) key-value pairs for transforming labels to new values.
        e.g. {"Yes": 1, "No": 0};
        **new in ver 1.9: during training, input labels not found in `label_encoder` will retain
        their original values

    label_list : None or list, default : None
        List all input labels, used for matching types of original keys in label_encoder dict,
        length should match key count in label_encoder, e.g. ["Yes", "No"];
        **new in ver 1.9: given non-empty `label_encoder`, when `label_list` is not provided,
        the module will infer label types from input data

    need_run: bool, default: True
        Specify whether to run label transform

    """

    def __init__(self, label_encoder=None, label_list=None, need_run=True):
        super(LabelTransformParam, self).__init__()
        self.label_encoder = label_encoder
        self.label_list = label_list
        self.need_run = need_run

    def check(self):
        """Validate the label transform parameters, normalizing empty containers.

        An empty ``label_encoder`` dict and an empty ``label_list`` are both
        replaced with ``None``. ``need_run`` is validated through
        ``BaseParam.check_boolean``.

        Returns
        -------
        bool
            True once all checks pass.

        Raises
        ------
        ValueError
            If ``label_encoder`` is not a dict, if ``label_list`` is not a list,
            or if both are non-empty and ``label_list``'s length differs from
            the number of keys in ``label_encoder``.
        """
        model_param_descr = "label transform param's "

        BaseParam.check_boolean(self.need_run, f"{model_param_descr} need run ")

        if self.label_encoder is not None:
            if not isinstance(self.label_encoder, dict):
                raise ValueError(f"{model_param_descr} label_encoder should be dict type")
            # An empty mapping is treated the same as "no encoder supplied".
            if not self.label_encoder:
                self.label_encoder = None

        if self.label_list is not None:
            if not isinstance(self.label_list, list):
                raise ValueError(f"{model_param_descr} label_list should be list type")
            # Lengths are compared only when both are non-empty; the encoder has
            # already been normalized above, so an empty dict never triggers this.
            if self.label_encoder and self.label_list and len(self.label_list) != len(self.label_encoder):
                raise ValueError(
                    f"{model_param_descr} label_list's length not matching label_encoder key count."
                )
            if not self.label_list:
                self.label_list = None

        LOGGER.debug("Finish label transformer parameter check!")
        return True
|