# boosting_param.py
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
import collections
import collections.abc
import copy

from federatedml.param.base_param import BaseParam, deprecated_param
from federatedml.param.encrypt_param import EncryptParam
from federatedml.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
from federatedml.param.cross_validation_param import CrossValidationParam
from federatedml.param.predict_param import PredictParam
from federatedml.param.callback_param import CallbackParam
from federatedml.util import consts, LOGGER
  27. hetero_deprecated_param_list = ["early_stopping_rounds", "validation_freqs", "metrics", "use_first_metric_only"]
  28. homo_deprecated_param_list = ["validation_freqs", "metrics"]
  29. class ObjectiveParam(BaseParam):
  30. """
  31. Define objective parameters that used in federated ml.
  32. Parameters
  33. ----------
  34. objective : {None, 'cross_entropy', 'lse', 'lae', 'log_cosh', 'tweedie', 'fair', 'huber'}
  35. None in host's config, should be str in guest'config.
  36. when task_type is classification, only support 'cross_entropy',
  37. other 6 types support in regression task
  38. params : None or list
  39. should be non empty list when objective is 'tweedie','fair','huber',
  40. first element of list shoulf be a float-number large than 0.0 when objective is 'fair', 'huber',
  41. first element of list should be a float-number in [1.0, 2.0) when objective is 'tweedie'
  42. """
  43. def __init__(self, objective='cross_entropy', params=None):
  44. self.objective = objective
  45. self.params = params
  46. def check(self, task_type=None):
  47. if self.objective is None:
  48. return True
  49. descr = "objective param's"
  50. LOGGER.debug('check objective {}'.format(self.objective))
  51. if task_type not in [consts.CLASSIFICATION, consts.REGRESSION]:
  52. self.objective = self.check_and_change_lower(self.objective,
  53. ["cross_entropy", "lse", "lae", "huber", "fair",
  54. "log_cosh", "tweedie"],
  55. descr)
  56. if task_type == consts.CLASSIFICATION:
  57. if self.objective != "cross_entropy":
  58. raise ValueError("objective param's objective {} not supported".format(self.objective))
  59. elif task_type == consts.REGRESSION:
  60. self.objective = self.check_and_change_lower(self.objective,
  61. ["lse", "lae", "huber", "fair", "log_cosh", "tweedie"],
  62. descr)
  63. params = self.params
  64. if self.objective in ["huber", "fair", "tweedie"]:
  65. if type(params).__name__ != 'list' or len(params) < 1:
  66. raise ValueError(
  67. "objective param's params {} not supported, should be non-empty list".format(params))
  68. if type(params[0]).__name__ not in ["float", "int", "long"]:
  69. raise ValueError("objective param's params[0] {} not supported".format(self.params[0]))
  70. if self.objective == 'tweedie':
  71. if params[0] < 1 or params[0] >= 2:
  72. raise ValueError("in tweedie regression, objective params[0] should betweend [1, 2)")
  73. if self.objective == 'fair' or 'huber':
  74. if params[0] <= 0.0:
  75. raise ValueError("in {} regression, objective params[0] should greater than 0.0".format(
  76. self.objective))
  77. return True
  78. class DecisionTreeParam(BaseParam):
  79. """
  80. Define decision tree parameters that used in federated ml.
  81. Parameters
  82. ----------
  83. criterion_method : {"xgboost"}, default: "xgboost"
  84. the criterion function to use
  85. criterion_params: list or dict
  86. should be non empty and elements are float-numbers,
  87. if a list is offered, the first one is l2 regularization value, and the second one is
  88. l1 regularization value.
  89. if a dict is offered, make sure it contains key 'l1', and 'l2'.
  90. l1, l2 regularization values are non-negative floats.
  91. default: [0.1, 0] or {'l1':0, 'l2':0,1}
  92. max_depth: positive integer
  93. the max depth of a decision tree, default: 3
  94. min_sample_split: int
  95. least quantity of nodes to split, default: 2
  96. min_impurity_split: float
  97. least gain of a single split need to reach, default: 1e-3
  98. min_child_weight: float
  99. sum of hessian needed in child nodes. default is 0
  100. min_leaf_node: int
  101. when samples no more than min_leaf_node, it becomes a leave, default: 1
  102. max_split_nodes: positive integer
  103. we will use no more than max_split_nodes to
  104. parallel finding their splits in a batch, for memory consideration. default is 65536
  105. feature_importance_type: {'split', 'gain'}
  106. if is 'split', feature_importances calculate by feature split times,
  107. if is 'gain', feature_importances calculate by feature split gain.
  108. default: 'split'
  109. Due to the safety concern, we adjust training strategy of Hetero-SBT in FATE-1.8,
  110. When running Hetero-SBT, this parameter is now abandoned.
  111. In Hetero-SBT of FATE-1.8, guest side will compute split, gain of local features,
  112. and receive anonymous feature importance results from hosts. Hosts will compute split
  113. importance of local features.
  114. use_missing: bool, accepted True, False only, default: False
  115. use missing value in training process or not.
  116. zero_as_missing: bool
  117. regard 0 as missing value or not,
  118. will be use only if use_missing=True, default: False
  119. deterministic: bool
  120. ensure stability when computing histogram. Set this to true to ensure stable result when using
  121. same data and same parameter. But it may slow down computation.
  122. """
  123. def __init__(self, criterion_method="xgboost", criterion_params=[0.1, 0], max_depth=3,
  124. min_sample_split=2, min_impurity_split=1e-3, min_leaf_node=1,
  125. max_split_nodes=consts.MAX_SPLIT_NODES, feature_importance_type='split',
  126. n_iter_no_change=True, tol=0.001, min_child_weight=0,
  127. use_missing=False, zero_as_missing=False, deterministic=False):
  128. super(DecisionTreeParam, self).__init__()
  129. self.criterion_method = criterion_method
  130. self.criterion_params = criterion_params
  131. self.max_depth = max_depth
  132. self.min_sample_split = min_sample_split
  133. self.min_impurity_split = min_impurity_split
  134. self.min_leaf_node = min_leaf_node
  135. self.min_child_weight = min_child_weight
  136. self.max_split_nodes = max_split_nodes
  137. self.feature_importance_type = feature_importance_type
  138. self.n_iter_no_change = n_iter_no_change
  139. self.tol = tol
  140. self.use_missing = use_missing
  141. self.zero_as_missing = zero_as_missing
  142. self.deterministic = deterministic
  143. def check(self):
  144. descr = "decision tree param"
  145. self.criterion_method = self.check_and_change_lower(self.criterion_method,
  146. ["xgboost"],
  147. descr)
  148. if len(self.criterion_params) == 0:
  149. raise ValueError("decisition tree param's criterio_params should be non empty")
  150. if isinstance(self.criterion_params, list):
  151. assert len(self.criterion_params) == 2, 'length of criterion_param should be 2: l1, l2 regularization ' \
  152. 'values are needed'
  153. self.check_nonnegative_number(self.criterion_params[0], 'l2 reg value')
  154. self.check_nonnegative_number(self.criterion_params[1], 'l1 reg value')
  155. elif isinstance(self.criterion_params, dict):
  156. assert 'l1' in self.criterion_params and 'l2' in self.criterion_params, 'l1 and l2 keys are needed in ' \
  157. 'criterion_params dict'
  158. self.criterion_params = [self.criterion_params['l2'], self.criterion_params['l1']]
  159. else:
  160. raise ValueError('criterion_params should be a dict or a list contains l1, l2 reg value')
  161. if type(self.max_depth).__name__ not in ["int", "long"]:
  162. raise ValueError("decision tree param's max_depth {} not supported, should be integer".format(
  163. self.max_depth))
  164. if self.max_depth < 1:
  165. raise ValueError("decision tree param's max_depth should be positive integer, no less than 1")
  166. if type(self.min_sample_split).__name__ not in ["int", "long"]:
  167. raise ValueError("decision tree param's min_sample_split {} not supported, should be integer".format(
  168. self.min_sample_split))
  169. if type(self.min_impurity_split).__name__ not in ["int", "long", "float"]:
  170. raise ValueError("decision tree param's min_impurity_split {} not supported, should be numeric".format(
  171. self.min_impurity_split))
  172. if type(self.min_leaf_node).__name__ not in ["int", "long"]:
  173. raise ValueError("decision tree param's min_leaf_node {} not supported, should be integer".format(
  174. self.min_leaf_node))
  175. if type(self.max_split_nodes).__name__ not in ["int", "long"] or self.max_split_nodes < 1:
  176. raise ValueError("decision tree param's max_split_nodes {} not supported, " +
  177. "should be positive integer between 1 and {}".format(self.max_split_nodes,
  178. consts.MAX_SPLIT_NODES))
  179. if type(self.n_iter_no_change).__name__ != "bool":
  180. raise ValueError("decision tree param's n_iter_no_change {} not supported, should be bool type".format(
  181. self.n_iter_no_change))
  182. if type(self.tol).__name__ not in ["float", "int", "long"]:
  183. raise ValueError("decision tree param's tol {} not supported, should be numeric".format(self.tol))
  184. self.feature_importance_type = self.check_and_change_lower(self.feature_importance_type,
  185. ["split", "gain"],
  186. descr)
  187. self.check_nonnegative_number(self.min_child_weight, 'min_child_weight')
  188. self.check_boolean(self.deterministic, 'deterministic')
  189. return True
  190. class BoostingParam(BaseParam):
  191. """
  192. Basic parameter for Boosting Algorithms
  193. Parameters
  194. ----------
  195. task_type : {'classification', 'regression'}, default: 'classification'
  196. task type
  197. objective_param : ObjectiveParam Object, default: ObjectiveParam()
  198. objective param
  199. learning_rate : float, int or long
  200. the learning rate of secure boost. default: 0.3
  201. num_trees : int or float
  202. the max number of boosting round. default: 5
  203. subsample_feature_rate : float
  204. a float-number in [0, 1], default: 1.0
  205. n_iter_no_change : bool,
  206. when True and residual error less than tol, tree building process will stop. default: True
  207. bin_num: positive integer greater than 1
  208. bin number use in quantile. default: 32
  209. validation_freqs: None or positive integer or container object in python
  210. Do validation in training process or Not.
  211. if equals None, will not do validation in train process;
  212. if equals positive integer, will validate data every validation_freqs epochs passes;
  213. if container object in python, will validate data if epochs belong to this container.
  214. e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
  215. Default: None
  216. """
  217. def __init__(self, task_type=consts.CLASSIFICATION,
  218. objective_param=ObjectiveParam(),
  219. learning_rate=0.3, num_trees=5, subsample_feature_rate=1, n_iter_no_change=True,
  220. tol=0.0001, bin_num=32,
  221. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  222. validation_freqs=None, metrics=None, random_seed=100,
  223. binning_error=consts.DEFAULT_RELATIVE_ERROR):
  224. super(BoostingParam, self).__init__()
  225. self.task_type = task_type
  226. self.objective_param = copy.deepcopy(objective_param)
  227. self.learning_rate = learning_rate
  228. self.num_trees = num_trees
  229. self.subsample_feature_rate = subsample_feature_rate
  230. self.n_iter_no_change = n_iter_no_change
  231. self.tol = tol
  232. self.bin_num = bin_num
  233. self.predict_param = copy.deepcopy(predict_param)
  234. self.cv_param = copy.deepcopy(cv_param)
  235. self.validation_freqs = validation_freqs
  236. self.metrics = metrics
  237. self.random_seed = random_seed
  238. self.binning_error = binning_error
  239. def check(self):
  240. descr = "boosting tree param's"
  241. if self.task_type not in [consts.CLASSIFICATION, consts.REGRESSION]:
  242. raise ValueError("boosting_core tree param's task_type {} not supported, should be {} or {}".format(
  243. self.task_type, consts.CLASSIFICATION, consts.REGRESSION))
  244. self.objective_param.check(self.task_type)
  245. if type(self.learning_rate).__name__ not in ["float", "int", "long"]:
  246. raise ValueError("boosting_core tree param's learning_rate {} not supported, should be numeric".format(
  247. self.learning_rate))
  248. if type(self.subsample_feature_rate).__name__ not in ["float", "int", "long"] or \
  249. self.subsample_feature_rate < 0 or self.subsample_feature_rate > 1:
  250. raise ValueError(
  251. "boosting_core tree param's subsample_feature_rate should be a numeric number between 0 and 1")
  252. if type(self.n_iter_no_change).__name__ != "bool":
  253. raise ValueError("boosting_core tree param's n_iter_no_change {} not supported, should be bool type".format(
  254. self.n_iter_no_change))
  255. if type(self.tol).__name__ not in ["float", "int", "long"]:
  256. raise ValueError("boosting_core tree param's tol {} not supported, should be numeric".format(self.tol))
  257. if type(self.bin_num).__name__ not in ["int", "long"] or self.bin_num < 2:
  258. raise ValueError(
  259. "boosting_core tree param's bin_num {} not supported, should be positive integer greater than 1".format(
  260. self.bin_num))
  261. if self.validation_freqs is None:
  262. pass
  263. elif isinstance(self.validation_freqs, int):
  264. if self.validation_freqs < 1:
  265. raise ValueError("validation_freqs should be larger than 0 when it's integer")
  266. elif not isinstance(self.validation_freqs, collections.Container):
  267. raise ValueError("validation_freqs should be None or positive integer or container")
  268. if self.metrics is not None and not isinstance(self.metrics, list):
  269. raise ValueError("metrics should be a list")
  270. if self.random_seed is not None:
  271. assert isinstance(self.random_seed, int) and self.random_seed >= 0, 'random seed must be an integer >= 0'
  272. self.check_decimal_float(self.binning_error, descr)
  273. return True
  274. class HeteroBoostingParam(BoostingParam):
  275. """
  276. Parameters
  277. ----------
  278. encrypt_param : EncodeParam Object
  279. encrypt method use in secure boost, default: EncryptParam()
  280. encrypted_mode_calculator_param: EncryptedModeCalculatorParam object
  281. the calculation mode use in secureboost,
  282. default: EncryptedModeCalculatorParam()
  283. """
  284. def __init__(self, task_type=consts.CLASSIFICATION,
  285. objective_param=ObjectiveParam(),
  286. learning_rate=0.3, num_trees=5, subsample_feature_rate=1, n_iter_no_change=True,
  287. tol=0.0001, encrypt_param=EncryptParam(),
  288. bin_num=32,
  289. encrypted_mode_calculator_param=EncryptedModeCalculatorParam(),
  290. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  291. validation_freqs=None, early_stopping_rounds=None, metrics=None, use_first_metric_only=False,
  292. random_seed=100, binning_error=consts.DEFAULT_RELATIVE_ERROR):
  293. super(HeteroBoostingParam, self).__init__(task_type, objective_param, learning_rate, num_trees,
  294. subsample_feature_rate, n_iter_no_change, tol, bin_num,
  295. predict_param, cv_param, validation_freqs, metrics=metrics,
  296. random_seed=random_seed,
  297. binning_error=binning_error)
  298. self.encrypt_param = copy.deepcopy(encrypt_param)
  299. self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
  300. self.early_stopping_rounds = early_stopping_rounds
  301. self.use_first_metric_only = use_first_metric_only
  302. def check(self):
  303. super(HeteroBoostingParam, self).check()
  304. self.encrypted_mode_calculator_param.check()
  305. self.encrypt_param.check()
  306. if self.early_stopping_rounds is None:
  307. pass
  308. elif isinstance(self.early_stopping_rounds, int):
  309. if self.early_stopping_rounds < 1:
  310. raise ValueError("early stopping rounds should be larger than 0 when it's integer")
  311. if self.validation_freqs is None:
  312. raise ValueError("validation freqs must be set when early stopping is enabled")
  313. if not isinstance(self.use_first_metric_only, bool):
  314. raise ValueError("use_first_metric_only should be a boolean")
  315. return True
@deprecated_param(*hetero_deprecated_param_list)
class HeteroSecureBoostParam(HeteroBoostingParam):
    """
    Define boosting tree parameters that used in federated ml.

    Parameters
    ----------
    task_type : {'classification', 'regression'}, default: 'classification'
        task type

    tree_param : DecisionTreeParam Object, default: DecisionTreeParam()
        tree param

    objective_param : ObjectiveParam Object, default: ObjectiveParam()
        objective param

    learning_rate : float, int or long
        the learning rate of secure boost. default: 0.3

    num_trees : int or float
        the max number of trees to build. default: 5

    subsample_feature_rate : float
        a float-number in [0, 1], default: 1.0

    random_seed: int
        seed that controls all random functions

    n_iter_no_change : bool,
        when True and residual error less than tol, tree building process will stop. default: True

    encrypt_param : EncodeParam Object
        encrypt method use in secure boost, default: EncryptParam(), this parameter
        is only for hetero-secureboost

    bin_num: positive integer greater than 1
        bin number use in quantile. default: 32

    encrypted_mode_calculator_param: EncryptedModeCalculatorParam object
        the calculation mode use in secureboost, default: EncryptedModeCalculatorParam(), only for hetero-secureboost

    use_missing: bool
        use missing value in training process or not. default: False

    zero_as_missing: bool
        regard 0 as missing value or not, will be use only if use_missing=True, default: False

    validation_freqs: None or positive integer or container object in python
        Do validation in training process or Not.
        if equals None, will not do validation in train process;
        if equals positive integer, will validate data every validation_freqs epochs passes;
        if container object in python, will validate data if epochs belong to this container.
        e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
        Default: None
        The default value is None, 1 is suggested. You can set it to a number larger than 1 in order to
        speed up training by skipping validation rounds. When it is larger than 1, a number which is
        divisible by "num_trees" is recommended, otherwise, you will miss the validation scores
        of last training iteration.

    early_stopping_rounds: integer larger than 0
        will stop training if one metric of one validation data
        doesn't improve in last early_stopping_round rounds,
        need to set validation freqs and will check early_stopping every at every validation epoch,

    metrics: list, default: []
        Specify which metrics to be used when performing evaluation during training process.
        If set as empty, default metrics will be used. For regression tasks, default metrics are
        ['root_mean_squared_error', 'mean_absolute_error'], For binary-classification tasks, default metrics
        are ['auc', 'ks']. For multi-classification tasks, default metrics are ['accuracy', 'precision', 'recall']

    use_first_metric_only: bool
        use only the first metric for early stopping

    complete_secure: bool
        if use complete_secure, when use complete secure, build first tree using only guest features

    sparse_optimization:
        this parameter is abandoned in FATE-1.7.1

    run_goss: bool
        activate Gradient-based One-Side Sampling, which selects large gradient and small
        gradient samples using top_rate and other_rate.

    top_rate: float, the retain ratio of large gradient data, used when run_goss is True

    other_rate: float, the retain ratio of small gradient data, used when run_goss is True

    cipher_compress_error: This param is now abandoned

    cipher_compress: bool, default is True, use cipher compressing to reduce computation cost and transfer cost

    boosting_strategy: str
        std: standard sbt setting
        mix: alternate using guest/host features to build trees. For example, the first 'tree_num_per_party' trees
        use guest features,
        the second k trees use host features, and so on
        layered: only support 2 party, when running layered mode, first 'host_depth' layer will use host features,
        and then next 'guest_depth' will only use guest features

    work_mode: str
        This parameter has the same function as boosting_strategy, but is deprecated

    tree_num_per_party: int, every party will alternate build 'tree_num_per_party' trees until reach max tree num, this
        param is valid when boosting_strategy is mix

    guest_depth: int, guest will build last guest_depth of a decision tree using guest features, is valid when boosting_strategy
        is layered

    host_depth: int, host will build first host_depth of a decision tree using host features, is valid when work boosting_strategy
        layered

    multi_mode: str, decide which mode to use when running multi-classification task:
        single_output standard gbdt multi-classification strategy
        multi_output every leaf give a multi-dimension predict, using multi_mode can save time
        by learning a model with less trees.

    EINI_inference: bool
        default is False, this option changes the inference algorithm used in predict tasks.
        a secure prediction method that hides decision path to enhance security in the inference
        step. This method is inspired by EINI inference algorithm.

    EINI_random_mask: bool
        default is False
        multiply predict result by a random float number to confuse original predict result. This operation further
        enhances the security of naive EINI algorithm.

    EINI_complexity_check: bool
        default is False
        check the complexity of tree models when running EINI algorithms. Complexity models are easy to hide their
        decision path, while simple tree models are not, therefore if a tree model is too simple, it is not allowed
        to run EINI predict algorithms.
    """

    def __init__(self, tree_param: DecisionTreeParam = DecisionTreeParam(), task_type=consts.CLASSIFICATION,
                 objective_param=ObjectiveParam(),
                 learning_rate=0.3, num_trees=5, subsample_feature_rate=1.0, n_iter_no_change=True,
                 tol=0.0001, encrypt_param=EncryptParam(),
                 bin_num=32,
                 encrypted_mode_calculator_param=EncryptedModeCalculatorParam(),
                 predict_param=PredictParam(), cv_param=CrossValidationParam(),
                 validation_freqs=None, early_stopping_rounds=None, use_missing=False, zero_as_missing=False,
                 complete_secure=False, metrics=None, use_first_metric_only=False, random_seed=100,
                 binning_error=consts.DEFAULT_RELATIVE_ERROR,
                 sparse_optimization=False, run_goss=False, top_rate=0.2, other_rate=0.1,
                 cipher_compress_error=None, cipher_compress=True, new_ver=True, boosting_strategy=consts.STD_TREE,
                 work_mode=None, tree_num_per_party=1, guest_depth=2, host_depth=3, callback_param=CallbackParam(),
                 multi_mode=consts.SINGLE_OUTPUT, EINI_inference=False, EINI_random_mask=False,
                 EINI_complexity_check=False):

        # Shared hetero-boosting options go to the parent; SBT-specific options
        # are stored below.
        super(HeteroSecureBoostParam, self).__init__(task_type, objective_param, learning_rate, num_trees,
                                                     subsample_feature_rate, n_iter_no_change, tol, encrypt_param,
                                                     bin_num, encrypted_mode_calculator_param, predict_param, cv_param,
                                                     validation_freqs, early_stopping_rounds, metrics=metrics,
                                                     use_first_metric_only=use_first_metric_only,
                                                     random_seed=random_seed,
                                                     binning_error=binning_error)
        # deep-copy mutable default objects so instances never share state
        self.tree_param = copy.deepcopy(tree_param)
        self.zero_as_missing = zero_as_missing
        self.use_missing = use_missing
        self.complete_secure = complete_secure
        self.sparse_optimization = sparse_optimization
        self.run_goss = run_goss
        self.top_rate = top_rate
        self.other_rate = other_rate
        self.cipher_compress_error = cipher_compress_error
        self.cipher_compress = cipher_compress
        self.new_ver = new_ver
        self.EINI_inference = EINI_inference
        self.EINI_random_mask = EINI_random_mask
        self.EINI_complexity_check = EINI_complexity_check
        self.boosting_strategy = boosting_strategy
        self.work_mode = work_mode
        self.tree_num_per_party = tree_num_per_party
        self.guest_depth = guest_depth
        self.host_depth = host_depth
        self.callback_param = copy.deepcopy(callback_param)
        self.multi_mode = multi_mode

    def check(self):
        """
        Validate all SBT parameters; also resolves the deprecated work_mode /
        validation parameters into boosting_strategy / callback_param.

        Returns
        -------
        bool
            True when all checks pass; raises ValueError otherwise.
        """

        super(HeteroSecureBoostParam, self).check()
        self.tree_param.check()
        if not isinstance(self.use_missing, bool):
            raise ValueError('use missing should be bool type')
        if not isinstance(self.zero_as_missing, bool):
            raise ValueError('zero as missing should be bool type')
        self.check_boolean(self.complete_secure, 'complete_secure')
        self.check_boolean(self.run_goss, 'run goss')
        self.check_decimal_float(self.top_rate, 'top rate')
        self.check_decimal_float(self.other_rate, 'other rate')
        self.check_positive_number(self.other_rate, 'other_rate')
        self.check_positive_number(self.top_rate, 'top_rate')
        self.check_boolean(self.new_ver, 'code version switcher')
        self.check_boolean(self.cipher_compress, 'cipher compress')
        self.check_boolean(self.EINI_inference, 'eini inference')
        self.check_boolean(self.EINI_random_mask, 'eini random mask')
        self.check_boolean(self.EINI_complexity_check, 'eini complexity check')

        if self.EINI_inference and self.EINI_random_mask:
            LOGGER.warning('To protect the inference decision path, notice that current setting will multiply'
                           ' predict result by a random number, hence SecureBoost will return confused predict scores'
                           ' that is not the same as the original predict scores')

        if self.work_mode == consts.MIX_TREE and self.EINI_inference:
            LOGGER.warning('Mix tree mode does not support EINI, use default predict setting')

        # deprecated work_mode overrides boosting_strategy when supplied
        if self.work_mode is not None:
            self.boosting_strategy = self.work_mode

        if self.multi_mode not in [consts.SINGLE_OUTPUT, consts.MULTI_OUTPUT]:
            raise ValueError('unsupported multi-classification mode')

        # multi-output trees only work with the standard strategy and require
        # cipher compress
        if self.multi_mode == consts.MULTI_OUTPUT:
            if self.boosting_strategy != consts.STD_TREE:
                raise ValueError('MO trees only works when boosting strategy is std tree')
            if not self.cipher_compress:
                raise ValueError('Mo trees only works when cipher compress is enabled')

        if self.boosting_strategy not in [consts.STD_TREE, consts.LAYERED_TREE, consts.MIX_TREE]:
            raise ValueError('unknown sbt boosting strategy{}'.format(self.boosting_strategy))

        # If any deprecated validation param was explicitly fed, it conflicts
        # with an explicit callback_param; otherwise enable PerformanceEvaluate
        # so the old behavior is preserved through the callback mechanism.
        for p in ["early_stopping_rounds", "validation_freqs", "metrics",
                  "use_first_metric_only"]:
            if self._deprecated_params_set.get(p):
                if "callback_param" in self.get_user_feeded():
                    raise ValueError(f"{p} and callback param should not be set simultaneously,"
                                     f"{self._deprecated_params_set}, {self.get_user_feeded()}")
                else:
                    self.callback_param.callbacks = ["PerformanceEvaluate"]
                break

        descr = "boosting_param's"

        # Mirror each deprecated parameter into callback_param, warning the user.
        if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
            self.callback_param.validation_freqs = self.validation_freqs

        if self._warn_to_deprecate_param("early_stopping_rounds", descr, "callback_param's 'early_stopping_rounds'"):
            self.callback_param.early_stopping_rounds = self.early_stopping_rounds

        if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
            self.callback_param.metrics = self.metrics

        if self._warn_to_deprecate_param("use_first_metric_only", descr, "callback_param's 'use_first_metric_only'"):
            self.callback_param.use_first_metric_only = self.use_first_metric_only

        # GOSS keeps top_rate + other_rate of the samples; the sum must stay < 1
        if self.top_rate + self.other_rate >= 1:
            raise ValueError('sum of top rate and other rate should be smaller than 1')

        return True
  515. @deprecated_param(*homo_deprecated_param_list)
  516. class HomoSecureBoostParam(BoostingParam):
  517. """
  518. Parameters
  519. ----------
  520. backend: {'distributed', 'memory'}
  521. decides which backend to use when computing histograms for homo-sbt
  522. """
  523. def __init__(self, tree_param: DecisionTreeParam = DecisionTreeParam(), task_type=consts.CLASSIFICATION,
  524. objective_param=ObjectiveParam(),
  525. learning_rate=0.3, num_trees=5, subsample_feature_rate=1, n_iter_no_change=True,
  526. tol=0.0001, bin_num=32, predict_param=PredictParam(), cv_param=CrossValidationParam(),
  527. validation_freqs=None, use_missing=False, zero_as_missing=False, random_seed=100,
  528. binning_error=consts.DEFAULT_RELATIVE_ERROR, backend=consts.DISTRIBUTED_BACKEND,
  529. callback_param=CallbackParam(), multi_mode=consts.SINGLE_OUTPUT):
  530. super(HomoSecureBoostParam, self).__init__(task_type=task_type,
  531. objective_param=objective_param,
  532. learning_rate=learning_rate,
  533. num_trees=num_trees,
  534. subsample_feature_rate=subsample_feature_rate,
  535. n_iter_no_change=n_iter_no_change,
  536. tol=tol,
  537. bin_num=bin_num,
  538. predict_param=predict_param,
  539. cv_param=cv_param,
  540. validation_freqs=validation_freqs,
  541. random_seed=random_seed,
  542. binning_error=binning_error
  543. )
  544. self.use_missing = use_missing
  545. self.zero_as_missing = zero_as_missing
  546. self.tree_param = copy.deepcopy(tree_param)
  547. self.backend = backend
  548. self.callback_param = copy.deepcopy(callback_param)
  549. self.multi_mode = multi_mode
  550. def check(self):
  551. super(HomoSecureBoostParam, self).check()
  552. self.tree_param.check()
  553. if not isinstance(self.use_missing, bool):
  554. raise ValueError('use missing should be bool type')
  555. if not isinstance(self.zero_as_missing, bool):
  556. raise ValueError('zero as missing should be bool type')
  557. if self.backend not in [consts.MEMORY_BACKEND, consts.DISTRIBUTED_BACKEND]:
  558. raise ValueError('unsupported backend')
  559. if self.multi_mode not in [consts.SINGLE_OUTPUT, consts.MULTI_OUTPUT]:
  560. raise ValueError('unsupported multi-classification mode')
  561. for p in ["validation_freqs", "metrics"]:
  562. # if self._warn_to_deprecate_param(p, "", ""):
  563. if self._deprecated_params_set.get(p):
  564. if "callback_param" in self.get_user_feeded():
  565. raise ValueError(f"{p} and callback param should not be set simultaneously,"
  566. f"{self._deprecated_params_set}, {self.get_user_feeded()}")
  567. else:
  568. self.callback_param.callbacks = ["PerformanceEvaluate"]
  569. break
  570. descr = "boosting_param's"
  571. if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
  572. self.callback_param.validation_freqs = self.validation_freqs
  573. if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
  574. self.callback_param.metrics = self.metrics
  575. if self.multi_mode not in [consts.SINGLE_OUTPUT, consts.MULTI_OUTPUT]:
  576. raise ValueError('unsupported multi-classification mode')
  577. if self.multi_mode == consts.MULTI_OUTPUT:
  578. if self.task_type == consts.REGRESSION:
  579. raise ValueError('regression tasks not support multi-output trees')
  580. return True