123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import builtins
- import json
- import os
- from federatedml.util import LOGGER, consts
# Instance attribute name (and key read from a non-raw conf) for the set of
# deprecated parameter keys that the user actually supplied.
- _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
# Class attribute name for the set of parameter names registered as deprecated.
- _DEPRECATED_PARAMS = "_deprecated_params"
# Instance attribute name (and key read from a non-raw conf) for the set of
# parameter keys that the user supplied.
- _USER_FEEDED_PARAMS = "_user_feeded_params"
# Conf key read by `BaseParam.update` (defaults to True when absent); the same
# name is set to False on the instance once a raw conf has been applied.
- _IS_RAW_CONF = "_is_raw_conf"
def deprecated_param(*names):
    """Build a class decorator that registers ``names`` as deprecated.

    The decorated class must expose the classmethod
    ``_get_or_init_deprecated_params_set``; every name is added to the set
    it returns.

    Args:
        *names: parameter names to mark as deprecated.

    Returns:
        A decorator that mutates the class's deprecated-name set and returns
        the very same class object.
    """

    def _decorator(cls: "type[BaseParam]"):
        # One bulk update instead of adding the names one by one.
        cls._get_or_init_deprecated_params_set().update(names)
        return cls

    return _decorator
- class _StaticDefaultMeta(type):
- """
- hook object creation, copy all default parameters in `__init__`
- """
- def __call__(cls, *args, **kwargs):
- obj = cls.__new__(cls)
- import inspect
- import copy
- signature = inspect.signature(obj.__init__).bind(*args, **kwargs)
- signature.apply_defaults()
- args = copy.deepcopy(signature.args)
- kwargs = copy.deepcopy(signature.kwargs)
- obj.__init__(*args, **kwargs)
- return obj
class BaseParam(metaclass=_StaticDefaultMeta):
    """Base class for parameter objects.

    Offers recursive conversion to a plain dict (`as_dict`), recursive
    population from a dict (`update`), bookkeeping of which parameters were
    supplied by the user (including deprecated ones), optional JSON-driven
    value validation (`validate`) and a collection of `check_*` helpers that
    raise `ValueError` on unsupported values.
    """

    def __init__(self):
        pass

    def set_name(self, name: str):
        """Store ``name`` on the instance as ``_name`` and return ``self``."""
        self._name = name
        return self

    def check(self):
        """Check the parameter object; this base implementation always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Parameter Object should be checked.")

    @classmethod
    def _get_or_init_deprecated_params_set(cls):
        """Return the class-level set of deprecated names, creating it if absent.

        NOTE(review): ``hasattr`` also sees attributes inherited from a base
        class, so a subclass whose base already owns a set will reuse (and
        mutate) that same set -- confirm this sharing is intended.
        """
        if not hasattr(cls, _DEPRECATED_PARAMS):
            setattr(cls, _DEPRECATED_PARAMS, set())
        return getattr(cls, _DEPRECATED_PARAMS)

    def _get_or_init_feeded_deprecated_params_set(self, conf=None):
        """Return the set of deprecated params the user supplied.

        On first access the set is created empty, or, when ``conf`` is given,
        from ``conf[_FEEDED_DEPRECATED_PARAMS]``. Later calls ignore ``conf``.
        """
        if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
            initial = set() if conf is None else set(conf[_FEEDED_DEPRECATED_PARAMS])
            setattr(self, _FEEDED_DEPRECATED_PARAMS, initial)
        return getattr(self, _FEEDED_DEPRECATED_PARAMS)

    def _get_or_init_user_feeded_params_set(self, conf=None):
        """Return the set of params the user supplied.

        On first access the set is created empty, or, when ``conf`` is given,
        from ``conf[_USER_FEEDED_PARAMS]``. Later calls ignore ``conf``.
        """
        if not hasattr(self, _USER_FEEDED_PARAMS):
            initial = set() if conf is None else set(conf[_USER_FEEDED_PARAMS])
            setattr(self, _USER_FEEDED_PARAMS, initial)
        return getattr(self, _USER_FEEDED_PARAMS)

    def get_user_feeded(self):
        """Return the set of parameter keys supplied by the user."""
        return self._get_or_init_user_feeded_params_set()

    def get_feeded_deprecated_params(self):
        """Return the set of deprecated parameter keys supplied by the user."""
        return self._get_or_init_feeded_deprecated_params_set()

    @property
    def _deprecated_params_set(self):
        """Mapping ``name -> True`` for every user-supplied deprecated param."""
        return {name: True for name in self.get_feeded_deprecated_params()}

    def as_dict(self):
        """Recursively convert this object's instance attributes to a dict.

        An attribute that is truthy and whose type name is not a name found
        in ``builtins`` is converted recursively through its own ``__dict__``;
        every other attribute is stored unchanged.
        """
        # Built once per call instead of once per attribute.
        builtin_names = frozenset(dir(builtins))

        def _recursive_convert_obj_to_dict(obj):
            ret_dict = {}
            # Iterate over a snapshot of the attribute names.
            for attr_name in list(obj.__dict__):
                attr = getattr(obj, attr_name)
                if attr and type(attr).__name__ not in builtin_names:
                    ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
                else:
                    ret_dict[attr_name] = attr
            return ret_dict

        return _recursive_convert_obj_to_dict(self)

    def update(self, conf, allow_redundant=False):
        """Recursively apply the values in ``conf`` to this object.

        ``conf`` is treated as a raw (user-written) conf unless it carries
        ``_IS_RAW_CONF`` with a falsy value. For a raw conf every applied key
        is recorded, with its dotted path, in the user-feeded set, and in the
        feeded-deprecated set when it is registered as deprecated. For a
        non-raw conf those sets are initialised from ``conf`` and unknown keys
        starting with ``_`` are copied onto the target as-is.

        Args:
            conf: mapping of parameter names to values; values for nested
                (non-builtin) attributes must themselves be mappings.
            allow_redundant: when False, keys that match no instance
                attribute cause a ``ValueError``.

        Returns:
            This object, after the update.

        Raises:
            ValueError: if nesting exceeds ``consts.PARAM_MAXDEPTH`` or
                redundant keys are present and not allowed.
        """
        update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
        if update_from_raw_conf:
            deprecated_params_set = self._get_or_init_deprecated_params_set()
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set()
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set()
            setattr(self, _IS_RAW_CONF, False)
        else:
            # Called for their side effect: seed the instance sets from conf.
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set(conf)
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)

        # Built once per call instead of once per config key.
        builtin_names = frozenset(dir(builtins))

        def _recursive_update_param(param, config, depth, prefix):
            if depth > consts.PARAM_MAXDEPTH:
                raise ValueError("Param define nesting too deep!!!, can not parse it")

            inst_variables = param.__dict__
            redundant_attrs = []
            for config_key, config_value in config.items():
                # Key does not correspond to an existing instance attribute.
                if config_key not in inst_variables:
                    if not update_from_raw_conf and config_key.startswith("_"):
                        setattr(param, config_key, config_value)
                    else:
                        redundant_attrs.append(config_key)
                    continue

                full_config_key = f"{prefix}{config_key}"

                if update_from_raw_conf:
                    # Remember that the user supplied this key ...
                    user_feeded_params_set.add(full_config_key)
                    # ... and whether it is a deprecated one.
                    if full_config_key in deprecated_params_set:
                        feeded_deprecated_params_set.add(full_config_key)

                attr = getattr(param, config_key)
                if type(attr).__name__ in builtin_names or attr is None:
                    # Builtin-typed (or unset) attribute: overwrite directly.
                    setattr(param, config_key, config_value)
                else:
                    # Nested parameter object: descend into it.
                    sub_params = _recursive_update_param(
                        attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
                    )
                    setattr(param, config_key, sub_params)

            if not allow_redundant and redundant_attrs:
                raise ValueError(
                    f"cpn `{getattr(self, '_name', type(self))}` "
                    f"has redundant parameters: `{redundant_attrs}`"
                )

            return param

        return _recursive_update_param(param=self, config=conf, depth=0, prefix="")

    def extract_not_builtin(self):
        """Return a nested dict containing only the non-builtin attributes.

        Each truthy attribute whose type name is not a name found in
        ``builtins`` maps to the same extraction applied to that attribute.
        """
        # Built once per call instead of once per attribute.
        builtin_names = frozenset(dir(builtins))

        def _get_not_builtin_types(obj):
            ret_dict = {}
            for variable in obj.__dict__:
                attr = getattr(obj, variable)
                if attr and type(attr).__name__ not in builtin_names:
                    ret_dict[variable] = _get_not_builtin_types(attr)
            return ret_dict

        return _get_not_builtin_types(self)

    def validate(self):
        """Validate attribute values against an optional JSON rule file.

        The rules are read from ``param_validation/<ClassName>.json`` next to
        this module. If that file cannot be opened or parsed, validation is
        skipped silently.

        NOTE(review): ``builtin_types`` and ``func`` are stored as instance
        attributes, so they appear in ``__dict__`` (and hence in ``as_dict``)
        after this call -- confirm that is acceptable.

        Raises:
            ValueError: if a value violates its restriction.
        """
        self.builtin_types = dir(builtins)
        self.func = {
            "ge": self._greater_equal_than,
            "le": self._less_equal_than,
            "in": self._in,
            "not_in": self._not_in,
            "range": self._range,
        }

        home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
        param_validation_path = os.path.join(
            home_dir, "param_validation", type(self).__name__ + ".json"
        )

        try:
            with open(param_validation_path, "r") as fin:
                validation_json = json.load(fin)
        except (OSError, ValueError):
            # Missing/unreadable file or malformed JSON: best-effort skip.
            # KeyboardInterrupt and SystemExit are deliberately not caught.
            return

        self._validate_param(self, validation_json)

    def _validate_param(self, param_obj, validation_json):
        """Check the attributes of ``param_obj`` against ``validation_json``.

        A builtin-typed (or None) attribute is accepted as soon as one of its
        listed operators passes; nested objects are validated recursively.
        Relies on ``self.builtin_types`` and ``self.func`` set by `validate`.

        Raises:
            ValueError: if no listed operator accepts a value.
        """
        default_section = type(param_obj).__name__
        var_list = param_obj.__dict__

        for variable in var_list:
            attr = getattr(param_obj, variable)

            if type(attr).__name__ in self.builtin_types or attr is None:
                # NOTE(review): membership is tested on the top level of
                # validation_json, but the rules are read from the
                # class-name section -- confirm the expected JSON layout.
                if variable not in validation_json:
                    continue

                validation_dict = validation_json[default_section][variable]
                value = getattr(param_obj, variable)
                value_legal = False

                for op_type in validation_dict:
                    if self.func[op_type](value, validation_dict[op_type]):
                        value_legal = True
                        break

                if not value_legal:
                    raise ValueError(
                        "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
                            variable, value
                        )
                    )

            elif variable in validation_json:
                self._validate_param(attr, validation_json)

    @staticmethod
    def check_string(param, descr):
        """Raise ``ValueError`` unless the type of ``param`` is exactly ``str``."""
        if type(param).__name__ not in ["str"]:
            raise ValueError(
                descr + " {} not supported, should be string type".format(param)
            )

    @staticmethod
    def check_positive_integer(param, descr):
        """Raise ``ValueError`` unless ``param`` is an ``int`` greater than 0."""
        if type(param).__name__ not in ["int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive integer".format(param)
            )

    @staticmethod
    def check_positive_number(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float greater than 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive numeric".format(param)
            )

    @staticmethod
    def check_nonnegative_number(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float that is >= 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param < 0:
            raise ValueError(
                descr
                + " {} not supported, should be non-negative numeric".format(param)
            )

    @staticmethod
    def check_decimal_float(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float in [0, 1]."""
        if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
            raise ValueError(
                descr
                + " {} not supported, should be a float number in range [0, 1]".format(
                    param
                )
            )

    @staticmethod
    def check_boolean(param, descr):
        """Raise ``ValueError`` unless the type of ``param`` is exactly ``bool``."""
        if type(param).__name__ != "bool":
            raise ValueError(
                descr + " {} not supported, should be bool type".format(param)
            )

    @staticmethod
    def check_open_unit_interval(param, descr):
        """Raise ``ValueError`` unless ``param`` is a ``float`` strictly in (0, 1)."""
        if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
            raise ValueError(
                descr + " should be a numeric number between 0 and 1 exclusively"
            )

    @staticmethod
    def check_valid_value(param, descr, valid_values):
        """Raise ``ValueError`` unless ``param`` is contained in ``valid_values``."""
        if param not in valid_values:
            raise ValueError(
                descr
                + " {} is not supported, it should be in {}".format(param, valid_values)
            )

    @staticmethod
    def check_defined_type(param, descr, types):
        """Raise ``ValueError`` unless the type name of ``param`` is in ``types``."""
        if type(param).__name__ not in types:
            raise ValueError(
                descr + " {} not supported, should be one of {}".format(param, types)
            )

    @staticmethod
    def check_and_change_lower(param, valid_list, descr=""):
        """Return ``param`` lower-cased if that form is in ``valid_list``.

        Raises:
            ValueError: if ``param`` is not a ``str`` or its lower-cased form
                is not in ``valid_list``.
        """
        if type(param).__name__ != "str":
            raise ValueError(
                descr
                + " {} not supported, should be one of {}".format(param, valid_list)
            )

        lower_param = param.lower()
        if lower_param in valid_list:
            return lower_param
        else:
            raise ValueError(
                descr
                + " {} not supported, should be one of {}".format(param, valid_list)
            )

    @staticmethod
    def _greater_equal_than(value, limit):
        """Return whether ``value >= limit`` within ``consts.FLOAT_ZERO`` slack."""
        return value >= limit - consts.FLOAT_ZERO

    @staticmethod
    def _less_equal_than(value, limit):
        """Return whether ``value <= limit`` within ``consts.FLOAT_ZERO`` slack."""
        return value <= limit + consts.FLOAT_ZERO

    @staticmethod
    def _range(value, ranges):
        """Return whether ``value`` lies in any ``(left, right)`` pair of ``ranges``.

        Both bounds are inclusive and widened by ``consts.FLOAT_ZERO``.
        """
        return any(
            left_limit - consts.FLOAT_ZERO <= value <= right_limit + consts.FLOAT_ZERO
            for left_limit, right_limit in ranges
        )

    @staticmethod
    def _in(value, right_value_list):
        """Return whether ``value`` is contained in ``right_value_list``."""
        return value in right_value_list

    @staticmethod
    def _not_in(value, wrong_value_list):
        """Return whether ``value`` is absent from ``wrong_value_list``."""
        return value not in wrong_value_list

    def _warn_deprecated_param(self, param_name, descr):
        """Log a warning if the user supplied the deprecated ``param_name``."""
        if self._deprecated_params_set.get(param_name):
            LOGGER.warning(
                f"{descr} {param_name} is deprecated and ignored in this version."
            )

    def _warn_to_deprecate_param(self, param_name, descr, new_param):
        """Log a warning if the user supplied the soon-deprecated ``param_name``.

        Returns:
            True if the warning was emitted, otherwise False.
        """
        if self._deprecated_params_set.get(param_name):
            LOGGER.warning(
                f"{descr} {param_name} will be deprecated in future release; "
                f"please use {new_param} instead."
            )
            return True
        return False
|