#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""Base class and helpers shared by component parameter objects."""

import builtins
import copy
import inspect
import json
import os

from federatedml.util import LOGGER, consts

# Attribute names under which bookkeeping sets/flags are stored on parameter
# classes and instances.
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
_USER_FEEDED_PARAMS = "_user_feeded_params"
_IS_RAW_CONF = "_is_raw_conf"

# Names exported by the `builtins` module. A value whose type name is in this
# set is treated as a plain (leaf) value rather than a nested parameter
# object. Computed once so membership tests do not rebuild `dir(builtins)`.
_BUILTIN_NAMES = frozenset(dir(builtins))


def deprecated_param(*names):
    """Return a class decorator that marks ``names`` as deprecated parameters.

    Each name is added to the decorated class's deprecated-parameter set
    (see ``BaseParam._get_or_init_deprecated_params_set``); the class itself
    is returned unchanged.
    """

    def _decorator(cls: "BaseParam"):
        deprecated = cls._get_or_init_deprecated_params_set()
        for name in names:
            deprecated.add(name)
        return cls

    return _decorator


class _StaticDefaultMeta(type):
    """
    hook object creation, copy all default parameters in `__init__`
    """

    def __call__(cls, *args, **kwargs):
        """Create an instance, passing deep copies of every ``__init__`` argument.

        The call arguments are bound against ``__init__``'s signature and the
        defaults are filled in, then positional and keyword arguments are
        deep-copied before ``__init__`` runs. As a result a mutable default
        (or a mutable argument supplied by the caller) is never shared
        between instances.
        """
        obj = cls.__new__(cls)

        bound = inspect.signature(obj.__init__).bind(*args, **kwargs)
        bound.apply_defaults()
        init_args = copy.deepcopy(bound.args)
        init_kwargs = copy.deepcopy(bound.kwargs)
        obj.__init__(*init_args, **init_kwargs)
        return obj


class BaseParam(metaclass=_StaticDefaultMeta):
    """Base class for parameter objects.

    Provides dict conversion, updating from a configuration mapping,
    optional JSON-driven validation, and a collection of static ``check_*``
    helpers that raise ``ValueError`` on invalid values.
    """

    def __init__(self):
        pass

    def set_name(self, name: str):
        """Store ``name`` on the instance as ``_name`` and return ``self``."""
        self._name = name
        return self

    def check(self):
        """Validate the parameter object; must be overridden by subclasses."""
        raise NotImplementedError("Parameter Object should be checked.")

    @classmethod
    def _get_or_init_deprecated_params_set(cls):
        """Return the class-level set of deprecated parameter names.

        The set is created on ``cls`` the first time it is requested. Because
        ``hasattr`` also sees inherited attributes, a subclass of a class that
        already owns the set reuses that same set object.
        """
        if not hasattr(cls, _DEPRECATED_PARAMS):
            setattr(cls, _DEPRECATED_PARAMS, set())
        return getattr(cls, _DEPRECATED_PARAMS)

    def _get_or_init_feeded_deprecated_params_set(self, conf=None):
        """Return the set of deprecated parameters that were actually supplied.

        On first access the set is created empty, or, when ``conf`` is given,
        initialised from ``conf[_FEEDED_DEPRECATED_PARAMS]``. Later calls
        return the existing set and ignore ``conf``.
        """
        if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
            if conf is None:
                setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
            else:
                setattr(
                    self,
                    _FEEDED_DEPRECATED_PARAMS,
                    set(conf[_FEEDED_DEPRECATED_PARAMS]),
                )
        return getattr(self, _FEEDED_DEPRECATED_PARAMS)

    def _get_or_init_user_feeded_params_set(self, conf=None):
        """Return the set of parameter keys that were supplied via ``update``.

        On first access the set is created empty, or, when ``conf`` is given,
        initialised from ``conf[_USER_FEEDED_PARAMS]``. Later calls return the
        existing set and ignore ``conf``.
        """
        if not hasattr(self, _USER_FEEDED_PARAMS):
            if conf is None:
                setattr(self, _USER_FEEDED_PARAMS, set())
            else:
                setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
        return getattr(self, _USER_FEEDED_PARAMS)

    def get_user_feeded(self):
        """Return the set of supplied parameter keys (created if missing)."""
        return self._get_or_init_user_feeded_params_set()

    def get_feeded_deprecated_params(self):
        """Return the set of supplied deprecated parameter keys (created if missing)."""
        return self._get_or_init_feeded_deprecated_params_set()

    @property
    def _deprecated_params_set(self):
        """Mapping ``{name: True}`` for every supplied deprecated parameter.

        Despite the name this is a dict, which lets callers use ``.get(name)``
        as a truthiness test.
        """
        return {name: True for name in self.get_feeded_deprecated_params()}

    def as_dict(self):
        """Recursively convert this object's instance attributes to a dict.

        An attribute that is truthy and whose type name is not a builtin name
        is converted recursively through its own ``__dict__``; every other
        attribute is stored as-is.
        """

        def _recursive_convert_obj_to_dict(obj):
            ret_dict = {}
            # Iterate over a snapshot of the attribute names.
            for attr_name in list(obj.__dict__):
                attr = getattr(obj, attr_name)
                if attr and type(attr).__name__ not in _BUILTIN_NAMES:
                    ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
                else:
                    ret_dict[attr_name] = attr

            return ret_dict

        return _recursive_convert_obj_to_dict(self)

    def update(self, conf, allow_redundant=False):
        """Update this object (and nested parameter objects) from ``conf``.

        ``conf`` is treated as a raw configuration unless it carries
        ``_is_raw_conf`` set to a falsy value. For a raw configuration every
        applied key is recorded (dotted path) in the user-feeded set, and in
        the feeded-deprecated set when it is registered as deprecated. For a
        non-raw configuration those two sets are restored from ``conf`` and
        unknown keys starting with ``_`` are set directly on the object.

        Args:
            conf: mapping of attribute name to value; nested mappings update
                nested parameter objects.
            allow_redundant: when false, keys that match no existing
                attribute cause a ``ValueError``.

        Returns:
            ``self`` after the update.

        Raises:
            ValueError: if nesting exceeds ``consts.PARAM_MAXDEPTH`` or if
                redundant keys are present and ``allow_redundant`` is false.
        """
        update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
        if update_from_raw_conf:
            deprecated_params_set = self._get_or_init_deprecated_params_set()
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set()
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set()
            # Mark the object so that a dict produced from it is recognised
            # as already processed.
            setattr(self, _IS_RAW_CONF, False)
        else:
            feeded_deprecated_params_set = (
                self._get_or_init_feeded_deprecated_params_set(conf)
            )
            user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)

        def _recursive_update_param(param, config, depth, prefix):
            if depth > consts.PARAM_MAXDEPTH:
                raise ValueError("Param define nesting too deep!!!, can not parse it")

            inst_variables = param.__dict__
            redundant_attrs = []
            for config_key, config_value in config.items():
                # Key does not correspond to an existing attribute.
                if config_key not in inst_variables:
                    if not update_from_raw_conf and config_key.startswith("_"):
                        setattr(param, config_key, config_value)
                    else:
                        redundant_attrs.append(config_key)
                    continue

                full_config_key = f"{prefix}{config_key}"

                if update_from_raw_conf:
                    # Record that the user supplied this parameter.
                    user_feeded_params_set.add(full_config_key)

                    # Record it separately if it is a deprecated parameter.
                    if full_config_key in deprecated_params_set:
                        feeded_deprecated_params_set.add(full_config_key)

                attr = getattr(param, config_key)
                if type(attr).__name__ in _BUILTIN_NAMES or attr is None:
                    # Leaf value: overwrite directly.
                    setattr(param, config_key, config_value)
                else:
                    # Nested parameter object: descend into it.
                    sub_params = _recursive_update_param(
                        attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
                    )
                    setattr(param, config_key, sub_params)

            if not allow_redundant and redundant_attrs:
                raise ValueError(
                    f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{redundant_attrs}`"
                )

            return param

        return _recursive_update_param(param=self, config=conf, depth=0, prefix="")

    def extract_not_builtin(self):
        """Return a nested dict containing only the non-builtin attributes.

        Each truthy attribute whose type name is not a builtin name is mapped
        to the result of the same extraction applied to it; all other
        attributes are omitted.
        """

        def _get_not_builtin_types(obj):
            ret_dict = {}
            for variable in obj.__dict__:
                attr = getattr(obj, variable)
                if attr and type(attr).__name__ not in _BUILTIN_NAMES:
                    ret_dict[variable] = _get_not_builtin_types(attr)

            return ret_dict

        return _get_not_builtin_types(self)

    def validate(self):
        """Validate attributes against ``param_validation/<ClassName>.json``.

        The rules file is looked up next to this module. Validation is
        best-effort: if the file cannot be opened or parsed the method
        returns silently without validating. Only ``Exception`` subclasses
        are suppressed, so ``KeyboardInterrupt`` and ``SystemExit`` still
        propagate.

        Side effects: sets ``self.builtin_types`` and ``self.func``.

        Raises:
            ValueError: from ``_validate_param`` when a value violates its rule.
        """
        self.builtin_types = dir(builtins)
        self.func = {
            "ge": self._greater_equal_than,
            "le": self._less_equal_than,
            "in": self._in,
            "not_in": self._not_in,
            "range": self._range,
        }
        home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
        param_validation_path = os.path.join(
            home_dir, "param_validation", type(self).__name__ + ".json"
        )

        try:
            with open(param_validation_path, "r") as fin:
                validation_json = json.load(fin)
        except Exception:
            # No usable rules file for this parameter class: skip validation.
            return

        self._validate_param(self, validation_json)

    def _validate_param(self, param_obj, validation_json):
        """Check ``param_obj``'s attributes against ``validation_json``.

        For a builtin-typed (or ``None``) attribute with rules, the value is
        legal as soon as any one listed operation accepts it. Non-builtin
        attributes that appear in ``validation_json`` are validated
        recursively.

        Raises:
            ValueError: if no rule operation accepts an attribute's value.
        """
        default_section = type(param_obj).__name__
        var_list = param_obj.__dict__

        for variable in var_list:
            attr = getattr(param_obj, variable)

            if type(attr).__name__ in self.builtin_types or attr is None:
                # NOTE(review): membership is tested on the top level of
                # validation_json but the rule is read from
                # validation_json[default_section]; confirm against the
                # rules-file schema that this is intended.
                if variable not in validation_json:
                    continue

                validation_dict = validation_json[default_section][variable]
                value = getattr(param_obj, variable)
                value_legal = False

                for op_type in validation_dict:
                    if self.func[op_type](value, validation_dict[op_type]):
                        value_legal = True
                        break

                if not value_legal:
                    raise ValueError(
                        "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
                            variable, value
                        )
                    )

            elif variable in validation_json:
                self._validate_param(attr, validation_json)

    @staticmethod
    def check_string(param, descr):
        """Raise ``ValueError`` unless ``param`` is exactly of type ``str``."""
        if type(param).__name__ not in ["str"]:
            raise ValueError(
                descr + " {} not supported, should be string type".format(param)
            )

    @staticmethod
    def check_positive_integer(param, descr):
        """Raise ``ValueError`` unless ``param`` is an ``int`` greater than zero.

        The type is matched by name, so ``bool`` values are rejected.
        """
        if type(param).__name__ not in ["int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive integer".format(param)
            )

    @staticmethod
    def check_positive_number(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float greater than zero."""
        if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
            raise ValueError(
                descr + " {} not supported, should be positive numeric".format(param)
            )

    @staticmethod
    def check_nonnegative_number(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float that is >= 0."""
        if type(param).__name__ not in ["float", "int", "long"] or param < 0:
            raise ValueError(
                descr
                + " {} not supported, should be non-negative numeric".format(param)
            )

    @staticmethod
    def check_decimal_float(param, descr):
        """Raise ``ValueError`` unless ``param`` is an int/float within [0, 1]."""
        if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
            raise ValueError(
                descr
                + " {} not supported, should be a float number in range [0, 1]".format(
                    param
                )
            )

    @staticmethod
    def check_boolean(param, descr):
        """Raise ``ValueError`` unless ``param`` is exactly of type ``bool``."""
        if type(param).__name__ != "bool":
            raise ValueError(
                descr + " {} not supported, should be bool type".format(param)
            )

    @staticmethod
    def check_open_unit_interval(param, descr):
        """Raise ``ValueError`` unless ``param`` is a ``float`` strictly inside (0, 1)."""
        if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
            raise ValueError(
                descr + " should be a numeric number between 0 and 1 exclusively"
            )

    @staticmethod
    def check_valid_value(param, descr, valid_values):
        """Raise ``ValueError`` unless ``param`` is contained in ``valid_values``."""
        if param not in valid_values:
            raise ValueError(
                descr
                + " {} is not supported, it should be in {}".format(param, valid_values)
            )

    @staticmethod
    def check_defined_type(param, descr, types):
        """Raise ``ValueError`` unless ``param``'s type name is listed in ``types``."""
        if type(param).__name__ not in types:
            raise ValueError(
                descr + " {} not supported, should be one of {}".format(param, types)
            )

    @staticmethod
    def check_and_change_lower(param, valid_list, descr=""):
        """Return ``param`` lower-cased if that form is in ``valid_list``.

        Raises:
            ValueError: if ``param`` is not a ``str`` or its lower-cased form
                is not in ``valid_list``.
        """
        if type(param).__name__ != "str":
            raise ValueError(
                descr
                + " {} not supported, should be one of {}".format(param, valid_list)
            )

        lower_param = param.lower()
        if lower_param in valid_list:
            return lower_param

        raise ValueError(
            descr + " {} not supported, should be one of {}".format(param, valid_list)
        )

    @staticmethod
    def _greater_equal_than(value, limit):
        """Return whether ``value >= limit`` within ``consts.FLOAT_ZERO`` tolerance."""
        return value >= limit - consts.FLOAT_ZERO

    @staticmethod
    def _less_equal_than(value, limit):
        """Return whether ``value <= limit`` within ``consts.FLOAT_ZERO`` tolerance."""
        return value <= limit + consts.FLOAT_ZERO

    @staticmethod
    def _range(value, ranges):
        """Return whether ``value`` lies in any ``(left, right)`` pair of ``ranges``.

        Both bounds are inclusive and widened by ``consts.FLOAT_ZERO``.
        """
        return any(
            left_limit - consts.FLOAT_ZERO <= value <= right_limit + consts.FLOAT_ZERO
            for left_limit, right_limit in ranges
        )

    @staticmethod
    def _in(value, right_value_list):
        """Return whether ``value`` is contained in ``right_value_list``."""
        return value in right_value_list

    @staticmethod
    def _not_in(value, wrong_value_list):
        """Return whether ``value`` is absent from ``wrong_value_list``."""
        return value not in wrong_value_list

    def _warn_deprecated_param(self, param_name, descr):
        """Log a warning if the deprecated ``param_name`` was supplied."""
        if self._deprecated_params_set.get(param_name):
            LOGGER.warning(
                f"{descr} {param_name} is deprecated and ignored in this version."
            )

    def _warn_to_deprecate_param(self, param_name, descr, new_param):
        """Log a warning if the soon-to-be-deprecated ``param_name`` was supplied.

        Returns:
            ``True`` if the parameter was supplied (and a warning logged),
            ``False`` otherwise.
        """
        if self._deprecated_params_set.get(param_name):
            LOGGER.warning(
                f"{descr} {param_name} will be deprecated in future release; "
                f"please use {new_param} instead."
            )
            return True
        return False