base_param.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import builtins
  19. import json
  20. import os
  21. from pipeline.param import consts
  22. class BaseParam(object):
  23. def __init__(self):
  24. pass
  25. def check(self):
  26. raise NotImplementedError("Parameter Object should be checked.")
  27. def validate(self):
  28. self.builtin_types = dir(builtins)
  29. self.func = {"ge": self._greater_equal_than,
  30. "le": self._less_equal_than,
  31. "in": self._in,
  32. "not_in": self._not_in,
  33. "range": self._range
  34. }
  35. home_dir = os.path.abspath(os.path.dirname(os.path.realpath(__file__)))
  36. param_validation_path_prefix = home_dir + "/param_validation/"
  37. param_name = type(self).__name__
  38. param_validation_path = "/".join([param_validation_path_prefix, param_name + ".json"])
  39. validation_json = None
  40. print("param validation path is {}".format(home_dir))
  41. try:
  42. with open(param_validation_path, "r") as fin:
  43. validation_json = json.loads(fin.read())
  44. except BaseException:
  45. return
  46. self._validate_param(self, validation_json)
  47. def _validate_param(self, param_obj, validation_json):
  48. default_section = type(param_obj).__name__
  49. var_list = param_obj.__dict__
  50. for variable in var_list:
  51. attr = getattr(param_obj, variable)
  52. if type(attr).__name__ in self.builtin_types or attr is None:
  53. if variable not in validation_json:
  54. continue
  55. validation_dict = validation_json[default_section][variable]
  56. value = getattr(param_obj, variable)
  57. value_legal = False
  58. for op_type in validation_dict:
  59. if self.func[op_type](value, validation_dict[op_type]):
  60. value_legal = True
  61. break
  62. if not value_legal:
  63. raise ValueError(
  64. "Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
  65. variable, value))
  66. elif variable in validation_json:
  67. self._validate_param(attr, validation_json)
  68. @staticmethod
  69. def check_string(param, descr):
  70. if type(param).__name__ not in ["str"]:
  71. raise ValueError(descr + " {} not supported, should be string type".format(param))
  72. @staticmethod
  73. def check_positive_integer(param, descr):
  74. if type(param).__name__ not in ["int", "long"] or param <= 0:
  75. raise ValueError(descr + " {} not supported, should be positive integer".format(param))
  76. @staticmethod
  77. def check_positive_number(param, descr):
  78. if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
  79. raise ValueError(descr + " {} not supported, should be positive numeric".format(param))
  80. @staticmethod
  81. def check_nonnegative_number(param, descr):
  82. if type(param).__name__ not in ["float", "int", "long"] or param < 0:
  83. raise ValueError(descr + " {} not supported, should be non-negative numeric".format(param))
  84. @staticmethod
  85. def check_decimal_float(param, descr):
  86. if type(param).__name__ not in ["float"] or param < 0 or param > 1:
  87. raise ValueError(descr + " {} not supported, should be a float number in range [0, 1]".format(param))
  88. @staticmethod
  89. def check_boolean(param, descr):
  90. if type(param).__name__ != "bool":
  91. raise ValueError(descr + " {} not supported, should be bool type".format(param))
  92. @staticmethod
  93. def check_open_unit_interval(param, descr):
  94. if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
  95. raise ValueError(descr + " should be a numeric number between 0 and 1 exclusively")
  96. @staticmethod
  97. def check_valid_value(param, descr, valid_values):
  98. if param not in valid_values:
  99. raise ValueError(descr + " {} is not supported, it should be in {}".format(param, valid_values))
  100. @staticmethod
  101. def check_defined_type(param, descr, types):
  102. if type(param).__name__ not in types:
  103. raise ValueError(descr + " {} not supported, should be one of {}".format(param, types))
  104. @staticmethod
  105. def check_and_change_lower(param, valid_list, descr=''):
  106. if type(param).__name__ != 'str':
  107. raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
  108. lower_param = param.lower()
  109. if lower_param in valid_list:
  110. return lower_param
  111. else:
  112. raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
  113. @staticmethod
  114. def _greater_equal_than(value, limit):
  115. return value >= limit - consts.FLOAT_ZERO
  116. @staticmethod
  117. def _less_equal_than(value, limit):
  118. return value <= limit + consts.FLOAT_ZERO
  119. @staticmethod
  120. def _range(value, ranges):
  121. in_range = False
  122. for left_limit, right_limit in ranges:
  123. if left_limit - consts.FLOAT_ZERO <= value <= right_limit + consts.FLOAT_ZERO:
  124. in_range = True
  125. break
  126. return in_range
  127. @staticmethod
  128. def _in(value, right_value_list):
  129. return value in right_value_list
  130. @staticmethod
  131. def _not_in(value, wrong_value_list):
  132. return value not in wrong_value_list