base_param.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import builtins
  19. import json
  20. import os
  21. from federatedml.util import LOGGER, consts
  22. _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
  23. _DEPRECATED_PARAMS = "_deprecated_params"
  24. _USER_FEEDED_PARAMS = "_user_feeded_params"
  25. _IS_RAW_CONF = "_is_raw_conf"
  26. def deprecated_param(*names):
  27. def _decorator(cls: "BaseParam"):
  28. deprecated = cls._get_or_init_deprecated_params_set()
  29. for name in names:
  30. deprecated.add(name)
  31. return cls
  32. return _decorator
  33. class _StaticDefaultMeta(type):
  34. """
  35. hook object creation, copy all default parameters in `__init__`
  36. """
  37. def __call__(cls, *args, **kwargs):
  38. obj = cls.__new__(cls)
  39. import inspect
  40. import copy
  41. signature = inspect.signature(obj.__init__).bind(*args, **kwargs)
  42. signature.apply_defaults()
  43. args = copy.deepcopy(signature.args)
  44. kwargs = copy.deepcopy(signature.kwargs)
  45. obj.__init__(*args, **kwargs)
  46. return obj
  47. class BaseParam(metaclass=_StaticDefaultMeta):
  48. def __init__(self):
  49. pass
  50. def set_name(self, name: str):
  51. self._name = name
  52. return self
  53. def check(self):
  54. raise NotImplementedError("Parameter Object should be checked.")
  55. @classmethod
  56. def _get_or_init_deprecated_params_set(cls):
  57. if not hasattr(cls, _DEPRECATED_PARAMS):
  58. setattr(cls, _DEPRECATED_PARAMS, set())
  59. return getattr(cls, _DEPRECATED_PARAMS)
  60. def _get_or_init_feeded_deprecated_params_set(self, conf=None):
  61. if not hasattr(self, _FEEDED_DEPRECATED_PARAMS):
  62. if conf is None:
  63. setattr(self, _FEEDED_DEPRECATED_PARAMS, set())
  64. else:
  65. setattr(
  66. self,
  67. _FEEDED_DEPRECATED_PARAMS,
  68. set(conf[_FEEDED_DEPRECATED_PARAMS]),
  69. )
  70. return getattr(self, _FEEDED_DEPRECATED_PARAMS)
  71. def _get_or_init_user_feeded_params_set(self, conf=None):
  72. if not hasattr(self, _USER_FEEDED_PARAMS):
  73. if conf is None:
  74. setattr(self, _USER_FEEDED_PARAMS, set())
  75. else:
  76. setattr(self, _USER_FEEDED_PARAMS, set(conf[_USER_FEEDED_PARAMS]))
  77. return getattr(self, _USER_FEEDED_PARAMS)
  78. def get_user_feeded(self):
  79. return self._get_or_init_user_feeded_params_set()
  80. def get_feeded_deprecated_params(self):
  81. return self._get_or_init_feeded_deprecated_params_set()
  82. @property
  83. def _deprecated_params_set(self):
  84. return {name: True for name in self.get_feeded_deprecated_params()}
  85. def as_dict(self):
  86. def _recursive_convert_obj_to_dict(obj):
  87. ret_dict = {}
  88. for attr_name in list(obj.__dict__):
  89. # get attr
  90. attr = getattr(obj, attr_name)
  91. if attr and type(attr).__name__ not in dir(builtins):
  92. ret_dict[attr_name] = _recursive_convert_obj_to_dict(attr)
  93. else:
  94. ret_dict[attr_name] = attr
  95. return ret_dict
  96. return _recursive_convert_obj_to_dict(self)
  97. def update(self, conf, allow_redundant=False):
  98. update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
  99. if update_from_raw_conf:
  100. deprecated_params_set = self._get_or_init_deprecated_params_set()
  101. feeded_deprecated_params_set = (
  102. self._get_or_init_feeded_deprecated_params_set()
  103. )
  104. user_feeded_params_set = self._get_or_init_user_feeded_params_set()
  105. setattr(self, _IS_RAW_CONF, False)
  106. else:
  107. feeded_deprecated_params_set = (
  108. self._get_or_init_feeded_deprecated_params_set(conf)
  109. )
  110. user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
  111. def _recursive_update_param(param, config, depth, prefix):
  112. if depth > consts.PARAM_MAXDEPTH:
  113. raise ValueError("Param define nesting too deep!!!, can not parse it")
  114. inst_variables = param.__dict__
  115. redundant_attrs = []
  116. for config_key, config_value in config.items():
  117. # redundant attr
  118. if config_key not in inst_variables:
  119. if not update_from_raw_conf and config_key.startswith("_"):
  120. setattr(param, config_key, config_value)
  121. else:
  122. redundant_attrs.append(config_key)
  123. continue
  124. full_config_key = f"{prefix}{config_key}"
  125. if update_from_raw_conf:
  126. # add user feeded params
  127. user_feeded_params_set.add(full_config_key)
  128. # update user feeded deprecated param set
  129. if full_config_key in deprecated_params_set:
  130. feeded_deprecated_params_set.add(full_config_key)
  131. # supported attr
  132. attr = getattr(param, config_key)
  133. if type(attr).__name__ in dir(builtins) or attr is None:
  134. setattr(param, config_key, config_value)
  135. else:
  136. # recursive set obj attr
  137. sub_params = _recursive_update_param(
  138. attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
  139. )
  140. setattr(param, config_key, sub_params)
  141. if not allow_redundant and redundant_attrs:
  142. raise ValueError(
  143. f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
  144. )
  145. return param
  146. return _recursive_update_param(param=self, config=conf, depth=0, prefix="")
  147. def extract_not_builtin(self):
  148. def _get_not_builtin_types(obj):
  149. ret_dict = {}
  150. for variable in obj.__dict__:
  151. attr = getattr(obj, variable)
  152. if attr and type(attr).__name__ not in dir(builtins):
  153. ret_dict[variable] = _get_not_builtin_types(attr)
  154. return ret_dict
  155. return _get_not_builtin_types(self)
  156. def validate(self):
  157. self.builtin_types = dir(builtins)
  158. self.func = {
  159. "ge": self._greater_equal_than,
  160. "le": self._less_equal_than,
  161. "in": self._in,
  162. "not_in": self._not_in,
  163. "range": self._range,
  164. }
  165. home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
  166. param_validation_path_prefix = home_dir + "/param_validation/"
  167. param_name = type(self).__name__
  168. param_validation_path = "/".join(
  169. [param_validation_path_prefix, param_name + ".json"]
  170. )
  171. validation_json = None
  172. try:
  173. with open(param_validation_path, "r") as fin:
  174. validation_json = json.loads(fin.read())
  175. except BaseException:
  176. return
  177. self._validate_param(self, validation_json)
  178. def _validate_param(self, param_obj, validation_json):
  179. default_section = type(param_obj).__name__
  180. var_list = param_obj.__dict__
  181. for variable in var_list:
  182. attr = getattr(param_obj, variable)
  183. if type(attr).__name__ in self.builtin_types or attr is None:
  184. if variable not in validation_json:
  185. continue
  186. validation_dict = validation_json[default_section][variable]
  187. value = getattr(param_obj, variable)
  188. value_legal = False
  189. for op_type in validation_dict:
  190. if self.func[op_type](value, validation_dict[op_type]):
  191. value_legal = True
  192. break
  193. if not value_legal:
  194. raise ValueError(
  195. "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
  196. variable, value
  197. )
  198. )
  199. elif variable in validation_json:
  200. self._validate_param(attr, validation_json)
  201. @staticmethod
  202. def check_string(param, descr):
  203. if type(param).__name__ not in ["str"]:
  204. raise ValueError(
  205. descr + " {} not supported, should be string type".format(param)
  206. )
  207. @staticmethod
  208. def check_positive_integer(param, descr):
  209. if type(param).__name__ not in ["int", "long"] or param <= 0:
  210. raise ValueError(
  211. descr + " {} not supported, should be positive integer".format(param)
  212. )
  213. @staticmethod
  214. def check_positive_number(param, descr):
  215. if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
  216. raise ValueError(
  217. descr + " {} not supported, should be positive numeric".format(param)
  218. )
  219. @staticmethod
  220. def check_nonnegative_number(param, descr):
  221. if type(param).__name__ not in ["float", "int", "long"] or param < 0:
  222. raise ValueError(
  223. descr
  224. + " {} not supported, should be non-negative numeric".format(param)
  225. )
  226. @staticmethod
  227. def check_decimal_float(param, descr):
  228. if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
  229. raise ValueError(
  230. descr
  231. + " {} not supported, should be a float number in range [0, 1]".format(
  232. param
  233. )
  234. )
  235. @staticmethod
  236. def check_boolean(param, descr):
  237. if type(param).__name__ != "bool":
  238. raise ValueError(
  239. descr + " {} not supported, should be bool type".format(param)
  240. )
  241. @staticmethod
  242. def check_open_unit_interval(param, descr):
  243. if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
  244. raise ValueError(
  245. descr + " should be a numeric number between 0 and 1 exclusively"
  246. )
  247. @staticmethod
  248. def check_valid_value(param, descr, valid_values):
  249. if param not in valid_values:
  250. raise ValueError(
  251. descr
  252. + " {} is not supported, it should be in {}".format(param, valid_values)
  253. )
  254. @staticmethod
  255. def check_defined_type(param, descr, types):
  256. if type(param).__name__ not in types:
  257. raise ValueError(
  258. descr + " {} not supported, should be one of {}".format(param, types)
  259. )
  260. @staticmethod
  261. def check_and_change_lower(param, valid_list, descr=""):
  262. if type(param).__name__ != "str":
  263. raise ValueError(
  264. descr
  265. + " {} not supported, should be one of {}".format(param, valid_list)
  266. )
  267. lower_param = param.lower()
  268. if lower_param in valid_list:
  269. return lower_param
  270. else:
  271. raise ValueError(
  272. descr
  273. + " {} not supported, should be one of {}".format(param, valid_list)
  274. )
  275. @staticmethod
  276. def _greater_equal_than(value, limit):
  277. return value >= limit - consts.FLOAT_ZERO
  278. @staticmethod
  279. def _less_equal_than(value, limit):
  280. return value <= limit + consts.FLOAT_ZERO
  281. @staticmethod
  282. def _range(value, ranges):
  283. in_range = False
  284. for left_limit, right_limit in ranges:
  285. if (
  286. left_limit - consts.FLOAT_ZERO
  287. <= value
  288. <= right_limit + consts.FLOAT_ZERO
  289. ):
  290. in_range = True
  291. break
  292. return in_range
  293. @staticmethod
  294. def _in(value, right_value_list):
  295. return value in right_value_list
  296. @staticmethod
  297. def _not_in(value, wrong_value_list):
  298. return value not in wrong_value_list
  299. def _warn_deprecated_param(self, param_name, descr):
  300. if self._deprecated_params_set.get(param_name):
  301. LOGGER.warning(
  302. f"{descr} {param_name} is deprecated and ignored in this version."
  303. )
  304. def _warn_to_deprecate_param(self, param_name, descr, new_param):
  305. if self._deprecated_params_set.get(param_name):
  306. LOGGER.warning(
  307. f"{descr} {param_name} will be deprecated in future release; "
  308. f"please use {new_param} instead."
  309. )
  310. return True
  311. return False