#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import collections.abc
import copy

from pipeline.param import consts
from pipeline.param.base_param import BaseParam
from pipeline.param.callback_param import CallbackParam
from pipeline.param.cross_validation_param import CrossValidationParam
from pipeline.param.encrypt_param import EncryptParam
from pipeline.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
from pipeline.param.predict_param import PredictParam
  26. class ObjectiveParam(BaseParam):
  27. """
  28. Define objective parameters that used in federated ml.
  29. Parameters
  30. ----------
  31. objective : {None, 'cross_entropy', 'lse', 'lae', 'log_cosh', 'tweedie', 'fair', 'huber'}
  32. None in host's config, should be str in guest'config.
  33. when task_type is classification, only support 'cross_entropy',
  34. other 6 types support in regression task
  35. params : None or list
  36. should be non empty list when objective is 'tweedie','fair','huber',
  37. first element of list shoulf be a float-number large than 0.0 when objective is 'fair', 'huber',
  38. first element of list should be a float-number in [1.0, 2.0) when objective is 'tweedie'
  39. """
  40. def __init__(self, objective='cross_entropy', params=None):
  41. self.objective = objective
  42. self.params = params
  43. def check(self, task_type=None):
  44. if self.objective is None:
  45. return True
  46. descr = "objective param's"
  47. if task_type not in [consts.CLASSIFICATION, consts.REGRESSION]:
  48. self.objective = self.check_and_change_lower(self.objective,
  49. ["cross_entropy", "lse", "lae", "huber", "fair",
  50. "log_cosh", "tweedie"],
  51. descr)
  52. if task_type == consts.CLASSIFICATION:
  53. if self.objective != "cross_entropy":
  54. raise ValueError("objective param's objective {} not supported".format(self.objective))
  55. elif task_type == consts.REGRESSION:
  56. self.objective = self.check_and_change_lower(self.objective,
  57. ["lse", "lae", "huber", "fair", "log_cosh", "tweedie"],
  58. descr)
  59. params = self.params
  60. if self.objective in ["huber", "fair", "tweedie"]:
  61. if type(params).__name__ != 'list' or len(params) < 1:
  62. raise ValueError(
  63. "objective param's params {} not supported, should be non-empty list".format(params))
  64. if type(params[0]).__name__ not in ["float", "int", "long"]:
  65. raise ValueError("objective param's params[0] {} not supported".format(self.params[0]))
  66. if self.objective == 'tweedie':
  67. if params[0] < 1 or params[0] >= 2:
  68. raise ValueError("in tweedie regression, objective params[0] should betweend [1, 2)")
  69. if self.objective == 'fair' or 'huber':
  70. if params[0] <= 0.0:
  71. raise ValueError("in {} regression, objective params[0] should greater than 0.0".format(
  72. self.objective))
  73. return True
  74. class DecisionTreeParam(BaseParam):
  75. """
  76. Define decision tree parameters that used in federated ml.
  77. Parameters
  78. ----------
  79. criterion_method : {"xgboost"}, default: "xgboost"
  80. the criterion function to use
  81. criterion_params: list or dict
  82. should be non empty and elements are float-numbers,
  83. if a list is offered, the first one is l2 regularization value, and the second one is
  84. l1 regularization value.
  85. if a dict is offered, make sure it contains key 'l1', and 'l2'.
  86. l1, l2 regularization values are non-negative floats.
  87. default: [0.1, 0] or {'l1':0, 'l2':0,1}
  88. max_depth: positive integer
  89. the max depth of a decision tree, default: 3
  90. min_sample_split: int
  91. least quantity of nodes to split, default: 2
  92. min_impurity_split: float
  93. least gain of a single split need to reach, default: 1e-3
  94. min_child_weight: float
  95. sum of hessian needed in child nodes. default is 0
  96. min_leaf_node: int
  97. when samples no more than min_leaf_node, it becomes a leave, default: 1
  98. max_split_nodes: positive integer
  99. we will use no more than max_split_nodes to
  100. parallel finding their splits in a batch, for memory consideration. default is 65536
  101. feature_importance_type: {'split', 'gain'}
  102. if is 'split', feature_importances calculate by feature split times,
  103. if is 'gain', feature_importances calculate by feature split gain.
  104. default: 'split'
  105. Due to the safety concern, we adjust training strategy of Hetero-SBT in FATE-1.8,
  106. When running Hetero-SBT, this parameter is now abandoned.
  107. In Hetero-SBT of FATE-1.8, guest side will compute split, gain of local features,
  108. and receive anonymous feature importance results from hosts. Hosts will compute split
  109. importance of local features.
  110. use_missing: bool, accepted True, False only, use missing value in training process or not. default: False
  111. zero_as_missing: bool
  112. regard 0 as missing value or not,
  113. will be use only if use_missing=True, default: False
  114. deterministic: bool
  115. ensure stability when computing histogram. Set this to true to ensure stable result when using
  116. same data and same parameter. But it may slow down computation.
  117. """
  118. def __init__(self, criterion_method="xgboost", criterion_params=[0.1, 0], max_depth=3,
  119. min_sample_split=2, min_impurity_split=1e-3, min_leaf_node=1,
  120. max_split_nodes=consts.MAX_SPLIT_NODES, feature_importance_type='split',
  121. n_iter_no_change=True, tol=0.001, min_child_weight=0,
  122. use_missing=False, zero_as_missing=False, deterministic=False):
  123. super(DecisionTreeParam, self).__init__()
  124. self.criterion_method = criterion_method
  125. self.criterion_params = criterion_params
  126. self.max_depth = max_depth
  127. self.min_sample_split = min_sample_split
  128. self.min_impurity_split = min_impurity_split
  129. self.min_leaf_node = min_leaf_node
  130. self.min_child_weight = min_child_weight
  131. self.max_split_nodes = max_split_nodes
  132. self.feature_importance_type = feature_importance_type
  133. self.n_iter_no_change = n_iter_no_change
  134. self.tol = tol
  135. self.use_missing = use_missing
  136. self.zero_as_missing = zero_as_missing
  137. self.deterministic = deterministic
  138. def check(self):
  139. descr = "decision tree param"
  140. self.criterion_method = self.check_and_change_lower(self.criterion_method,
  141. ["xgboost"],
  142. descr)
  143. if len(self.criterion_params) == 0:
  144. raise ValueError("decisition tree param's criterio_params should be non empty")
  145. if isinstance(self.criterion_params, list):
  146. assert len(self.criterion_params) == 2, 'length of criterion_param should be 2: l1, l2 regularization ' \
  147. 'values are needed'
  148. self.check_nonnegative_number(self.criterion_params[0], 'l2 reg value')
  149. self.check_nonnegative_number(self.criterion_params[1], 'l1 reg value')
  150. elif isinstance(self.criterion_params, dict):
  151. assert 'l1' in self.criterion_params and 'l2' in self.criterion_params, 'l1 and l2 keys are needed in ' \
  152. 'criterion_params dict'
  153. self.criterion_params = [self.criterion_params['l2'], self.criterion_params['l1']]
  154. else:
  155. raise ValueError('criterion_params should be a dict or a list contains l1, l2 reg value')
  156. if type(self.max_depth).__name__ not in ["int", "long"]:
  157. raise ValueError("decision tree param's max_depth {} not supported, should be integer".format(
  158. self.max_depth))
  159. if self.max_depth < 1:
  160. raise ValueError("decision tree param's max_depth should be positive integer, no less than 1")
  161. if type(self.min_sample_split).__name__ not in ["int", "long"]:
  162. raise ValueError("decision tree param's min_sample_split {} not supported, should be integer".format(
  163. self.min_sample_split))
  164. if type(self.min_impurity_split).__name__ not in ["int", "long", "float"]:
  165. raise ValueError("decision tree param's min_impurity_split {} not supported, should be numeric".format(
  166. self.min_impurity_split))
  167. if type(self.min_leaf_node).__name__ not in ["int", "long"]:
  168. raise ValueError("decision tree param's min_leaf_node {} not supported, should be integer".format(
  169. self.min_leaf_node))
  170. if type(self.max_split_nodes).__name__ not in ["int", "long"] or self.max_split_nodes < 1:
  171. raise ValueError("decision tree param's max_split_nodes {} not supported, " +
  172. "should be positive integer between 1 and {}".format(self.max_split_nodes,
  173. consts.MAX_SPLIT_NODES))
  174. if type(self.n_iter_no_change).__name__ != "bool":
  175. raise ValueError("decision tree param's n_iter_no_change {} not supported, should be bool type".format(
  176. self.n_iter_no_change))
  177. if type(self.tol).__name__ not in ["float", "int", "long"]:
  178. raise ValueError("decision tree param's tol {} not supported, should be numeric".format(self.tol))
  179. self.feature_importance_type = self.check_and_change_lower(self.feature_importance_type,
  180. ["split", "gain"],
  181. descr)
  182. self.check_nonnegative_number(self.min_child_weight, 'min_child_weight')
  183. self.check_boolean(self.deterministic, 'deterministic')
  184. return True
  185. class BoostingParam(BaseParam):
  186. """
  187. Basic parameter for Boosting Algorithms
  188. Parameters
  189. ----------
  190. task_type : {'classification', 'regression'}, default: 'classification'
  191. task type
  192. objective_param : ObjectiveParam Object, default: ObjectiveParam()
  193. objective param
  194. learning_rate : float, int or long
  195. the learning rate of secure boost. default: 0.3
  196. num_trees : int or float
  197. the max number of boosting round. default: 5
  198. subsample_feature_rate : float
  199. a float-number in [0, 1], default: 1.0
  200. n_iter_no_change : bool,
  201. when True and residual error less than tol, tree building process will stop. default: True
  202. bin_num: positive integer greater than 1
  203. bin number use in quantile. default: 32
  204. validation_freqs: None or positive integer or container object in python
  205. Do validation in training process or Not.
  206. if equals None, will not do validation in train process;
  207. if equals positive integer, will validate data every validation_freqs epochs passes;
  208. if container object in python, will validate data if epochs belong to this container.
  209. e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
  210. Default: None
  211. """
  212. def __init__(self, task_type=consts.CLASSIFICATION,
  213. objective_param=ObjectiveParam(),
  214. learning_rate=0.3, num_trees=5, subsample_feature_rate=1, n_iter_no_change=True,
  215. tol=0.0001, bin_num=32,
  216. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  217. validation_freqs=None, metrics=None, random_seed=100,
  218. binning_error=consts.DEFAULT_RELATIVE_ERROR):
  219. super(BoostingParam, self).__init__()
  220. self.task_type = task_type
  221. self.objective_param = copy.deepcopy(objective_param)
  222. self.learning_rate = learning_rate
  223. self.num_trees = num_trees
  224. self.subsample_feature_rate = subsample_feature_rate
  225. self.n_iter_no_change = n_iter_no_change
  226. self.tol = tol
  227. self.bin_num = bin_num
  228. self.predict_param = copy.deepcopy(predict_param)
  229. self.cv_param = copy.deepcopy(cv_param)
  230. self.validation_freqs = validation_freqs
  231. self.metrics = metrics
  232. self.random_seed = random_seed
  233. self.binning_error = binning_error
  234. def check(self):
  235. descr = "boosting tree param's"
  236. if self.task_type not in [consts.CLASSIFICATION, consts.REGRESSION]:
  237. raise ValueError("boosting_core tree param's task_type {} not supported, should be {} or {}".format(
  238. self.task_type, consts.CLASSIFICATION, consts.REGRESSION))
  239. self.objective_param.check(self.task_type)
  240. if type(self.learning_rate).__name__ not in ["float", "int", "long"]:
  241. raise ValueError("boosting_core tree param's learning_rate {} not supported, should be numeric".format(
  242. self.learning_rate))
  243. if type(self.subsample_feature_rate).__name__ not in ["float", "int", "long"] or \
  244. self.subsample_feature_rate < 0 or self.subsample_feature_rate > 1:
  245. raise ValueError(
  246. "boosting_core tree param's subsample_feature_rate should be a numeric number between 0 and 1")
  247. if type(self.n_iter_no_change).__name__ != "bool":
  248. raise ValueError("boosting_core tree param's n_iter_no_change {} not supported, should be bool type".format(
  249. self.n_iter_no_change))
  250. if type(self.tol).__name__ not in ["float", "int", "long"]:
  251. raise ValueError("boosting_core tree param's tol {} not supported, should be numeric".format(self.tol))
  252. if type(self.bin_num).__name__ not in ["int", "long"] or self.bin_num < 2:
  253. raise ValueError(
  254. "boosting_core tree param's bin_num {} not supported, should be positive integer greater than 1".format(
  255. self.bin_num))
  256. if self.validation_freqs is None:
  257. pass
  258. elif isinstance(self.validation_freqs, int):
  259. if self.validation_freqs < 1:
  260. raise ValueError("validation_freqs should be larger than 0 when it's integer")
  261. elif not isinstance(self.validation_freqs, collections.Container):
  262. raise ValueError("validation_freqs should be None or positive integer or container")
  263. if self.metrics is not None and not isinstance(self.metrics, list):
  264. raise ValueError("metrics should be a list")
  265. if self.random_seed is not None:
  266. assert isinstance(self.random_seed, int) and self.random_seed >= 0, 'random seed must be an integer >= 0'
  267. self.check_decimal_float(self.binning_error, descr)
  268. return True
  269. class HeteroBoostingParam(BoostingParam):
  270. """
  271. Parameters
  272. ----------
  273. encrypt_param : EncodeParam Object
  274. encrypt method use in secure boost, default: EncryptParam()
  275. encrypted_mode_calculator_param: EncryptedModeCalculatorParam object
  276. the calculation mode use in secureboost,
  277. default: EncryptedModeCalculatorParam()
  278. """
  279. def __init__(self, task_type=consts.CLASSIFICATION,
  280. objective_param=ObjectiveParam(),
  281. learning_rate=0.3, num_trees=5, subsample_feature_rate=1, n_iter_no_change=True,
  282. tol=0.0001, encrypt_param=EncryptParam(),
  283. bin_num=32,
  284. encrypted_mode_calculator_param=EncryptedModeCalculatorParam(),
  285. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  286. validation_freqs=None, early_stopping_rounds=None, metrics=None, use_first_metric_only=False,
  287. random_seed=100, binning_error=consts.DEFAULT_RELATIVE_ERROR):
  288. super(HeteroBoostingParam, self).__init__(task_type, objective_param, learning_rate, num_trees,
  289. subsample_feature_rate, n_iter_no_change, tol, bin_num,
  290. predict_param, cv_param, validation_freqs, metrics=metrics,
  291. random_seed=random_seed,
  292. binning_error=binning_error)
  293. self.encrypt_param = copy.deepcopy(encrypt_param)
  294. self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
  295. self.early_stopping_rounds = early_stopping_rounds
  296. self.use_first_metric_only = use_first_metric_only
  297. def check(self):
  298. super(HeteroBoostingParam, self).check()
  299. self.encrypted_mode_calculator_param.check()
  300. self.encrypt_param.check()
  301. if self.early_stopping_rounds is None:
  302. pass
  303. elif isinstance(self.early_stopping_rounds, int):
  304. if self.early_stopping_rounds < 1:
  305. raise ValueError("early stopping rounds should be larger than 0 when it's integer")
  306. if self.validation_freqs is None:
  307. raise ValueError("validation freqs must be set when early stopping is enabled")
  308. if not isinstance(self.use_first_metric_only, bool):
  309. raise ValueError("use_first_metric_only should be a boolean")
  310. return True
  311. class HeteroSecureBoostParam(HeteroBoostingParam):
  312. """
  313. Define boosting tree parameters that used in federated ml.
  314. Parameters
  315. ----------
  316. task_type : {'classification', 'regression'}, default: 'classification'
  317. task type
  318. tree_param : DecisionTreeParam Object, default: DecisionTreeParam()
  319. tree param
  320. objective_param : ObjectiveParam Object, default: ObjectiveParam()
  321. objective param
  322. learning_rate : float, int or long
  323. the learning rate of secure boost. default: 0.3
  324. num_trees : int or float
  325. the max number of trees to build. default: 5
  326. subsample_feature_rate : float
  327. a float-number in [0, 1], default: 1.0
  328. random_seed: int
  329. seed that controls all random functions
  330. n_iter_no_change : bool,
  331. when True and residual error less than tol, tree building process will stop. default: True
  332. encrypt_param : EncodeParam Object
  333. encrypt method use in secure boost, default: EncryptParam(), this parameter
  334. is only for hetero-secureboost
  335. bin_num: positive integer greater than 1
  336. bin number use in quantile. default: 32
  337. encrypted_mode_calculator_param: EncryptedModeCalculatorParam object
  338. the calculation mode use in secureboost, default: EncryptedModeCalculatorParam(), only for hetero-secureboost
  339. use_missing: bool
  340. use missing value in training process or not. default: False
  341. zero_as_missing: bool
  342. regard 0 as missing value or not, will be use only if use_missing=True, default: False
  343. validation_freqs: None or positive integer or container object in python
  344. Do validation in training process or Not.
  345. if equals None, will not do validation in train process;
  346. if equals positive integer, will validate data every validation_freqs epochs passes;
  347. if container object in python, will validate data if epochs belong to this container.
  348. e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
  349. Default: None
  350. The default value is None, 1 is suggested. You can set it to a number larger than 1 in order to
  351. speed up training by skipping validation rounds. When it is larger than 1, a number which is
  352. divisible by "num_trees" is recommended, otherwise, you will miss the validation scores
  353. of last training iteration.
  354. early_stopping_rounds: integer larger than 0
  355. will stop training if one metric of one validation data
  356. doesn’t improve in last early_stopping_round rounds,
  357. need to set validation freqs and will check early_stopping every at every validation epoch,
  358. metrics: list, default: []
  359. Specify which metrics to be used when performing evaluation during training process.
  360. If set as empty, default metrics will be used. For regression tasks, default metrics are
  361. ['root_mean_squared_error', 'mean_absolute_error'], For binary-classificatiin tasks, default metrics
  362. are ['auc', 'ks']. For multi-classification tasks, default metrics are ['accuracy', 'precision', 'recall']
  363. use_first_metric_only: bool
  364. use only the first metric for early stopping
  365. complete_secure: bool
  366. if use complete_secure, when use complete secure, build first tree using only guest features
  367. sparse_optimization:
  368. this parameter is abandoned in FATE-1.7.1
  369. run_goss: bool
  370. activate Gradient-based One-Side Sampling, which selects large gradient and small
  371. gradient samples using top_rate and other_rate.
  372. top_rate: float, the retain ratio of large gradient data, used when run_goss is True
  373. other_rate: float, the retain ratio of small gradient data, used when run_goss is True
  374. cipher_compress_error: This param is now abandoned
  375. cipher_compress: bool, default is True, use cipher compressing to reduce computation cost and transfer cost
  376. boosting_strategy:str
  377. std: standard sbt setting
  378. mix: alternate using guest/host features to build trees. For example, the first 'tree_num_per_party' trees
  379. use guest features,
  380. the second k trees use host features, and so on
  381. layered: only support 2 party, when running layered mode, first 'host_depth' layer will use host features,
  382. and then next 'guest_depth' will only use guest features
  383. work_mode: str
  384. This parameter has the same function as boosting_strategy, but is deprecated
  385. tree_num_per_party: int, every party will alternate build 'tree_num_per_party' trees until reach max tree num, this
  386. param is valid when boosting_strategy is mix
  387. guest_depth: int, guest will build last guest_depth of a decision tree using guest features, is valid when boosting_strategy
  388. is layered
  389. host_depth: int, host will build first host_depth of a decision tree using host features, is valid when work boosting_strategy
  390. layered
  391. multi_mode: str, decide which mode to use when running multi-classification task:
  392. single_output standard gbdt multi-classification strategy
  393. multi_output every leaf give a multi-dimension predict, using multi_mode can save time
  394. by learning a model with less trees.
  395. EINI_inference: bool
  396. default is False, this option changes the inference algorithm used in predict tasks.
  397. a secure prediction method that hides decision path to enhance security in the inference
  398. step. This method is insprired by EINI inference algorithm.
  399. EINI_random_mask: bool
  400. default is False
  401. multiply predict result by a random float number to confuse original predict result. This operation further
  402. enhances the security of naive EINI algorithm.
  403. EINI_complexity_check: bool
  404. default is False
  405. check the complexity of tree models when running EINI algorithms. Complexity models are easy to hide their
  406. decision path, while simple tree models are not, therefore if a tree model is too simple, it is not allowed
  407. to run EINI predict algorithms.
  408. """
  409. def __init__(self, tree_param: DecisionTreeParam = DecisionTreeParam(), task_type=consts.CLASSIFICATION,
  410. objective_param=ObjectiveParam(),
  411. learning_rate=0.3, num_trees=5, subsample_feature_rate=1.0, n_iter_no_change=True,
  412. tol=0.0001, encrypt_param=EncryptParam(),
  413. bin_num=32,
  414. encrypted_mode_calculator_param=EncryptedModeCalculatorParam(),
  415. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  416. validation_freqs=None, early_stopping_rounds=None, use_missing=False, zero_as_missing=False,
  417. complete_secure=False, metrics=None, use_first_metric_only=False, random_seed=100,
  418. binning_error=consts.DEFAULT_RELATIVE_ERROR,
  419. sparse_optimization=False, run_goss=False, top_rate=0.2, other_rate=0.1,
  420. cipher_compress_error=None, cipher_compress=True, new_ver=True, boosting_strategy=consts.STD_TREE,
  421. work_mode=None, tree_num_per_party=1, guest_depth=2, host_depth=3, callback_param=CallbackParam(),
  422. multi_mode=consts.SINGLE_OUTPUT, EINI_inference=False, EINI_random_mask=False,
  423. EINI_complexity_check=False):
  424. super(HeteroSecureBoostParam, self).__init__(task_type, objective_param, learning_rate, num_trees,
  425. subsample_feature_rate, n_iter_no_change, tol, encrypt_param,
  426. bin_num, encrypted_mode_calculator_param, predict_param, cv_param,
  427. validation_freqs, early_stopping_rounds, metrics=metrics,
  428. use_first_metric_only=use_first_metric_only,
  429. random_seed=random_seed,
  430. binning_error=binning_error)
  431. self.tree_param = copy.deepcopy(tree_param)
  432. self.zero_as_missing = zero_as_missing
  433. self.use_missing = use_missing
  434. self.complete_secure = complete_secure
  435. self.sparse_optimization = sparse_optimization
  436. self.run_goss = run_goss
  437. self.top_rate = top_rate
  438. self.other_rate = other_rate
  439. self.cipher_compress_error = cipher_compress_error
  440. self.cipher_compress = cipher_compress
  441. self.new_ver = new_ver
  442. self.EINI_inference = EINI_inference
  443. self.EINI_random_mask = EINI_random_mask
  444. self.EINI_complexity_check = EINI_complexity_check
  445. self.boosting_strategy = boosting_strategy
  446. self.work_mode = work_mode
  447. self.tree_num_per_party = tree_num_per_party
  448. self.guest_depth = guest_depth
  449. self.host_depth = host_depth
  450. self.callback_param = copy.deepcopy(callback_param)
  451. self.multi_mode = multi_mode
  452. def check(self):
  453. super(HeteroSecureBoostParam, self).check()
  454. self.tree_param.check()
  455. if not isinstance(self.use_missing, bool):
  456. raise ValueError('use missing should be bool type')
  457. if not isinstance(self.zero_as_missing, bool):
  458. raise ValueError('zero as missing should be bool type')
  459. self.check_boolean(self.complete_secure, 'complete_secure')
  460. self.check_boolean(self.run_goss, 'run goss')
  461. self.check_decimal_float(self.top_rate, 'top rate')
  462. self.check_decimal_float(self.other_rate, 'other rate')
  463. self.check_positive_number(self.other_rate, 'other_rate')
  464. self.check_positive_number(self.top_rate, 'top_rate')
  465. self.check_boolean(self.new_ver, 'code version switcher')
  466. self.check_boolean(self.cipher_compress, 'cipher compress')
  467. self.check_boolean(self.EINI_inference, 'eini inference')
  468. self.check_boolean(self.EINI_random_mask, 'eini random mask')
  469. self.check_boolean(self.EINI_complexity_check, 'eini complexity check')
  470. if self.work_mode is not None:
  471. self.boosting_strategy = self.work_mode
  472. if self.multi_mode not in [consts.SINGLE_OUTPUT, consts.MULTI_OUTPUT]:
  473. raise ValueError('unsupported multi-classification mode')
  474. if self.multi_mode == consts.MULTI_OUTPUT:
  475. if self.boosting_strategy != consts.STD_TREE:
  476. raise ValueError('MO trees only works when boosting strategy is std tree')
  477. if not self.cipher_compress:
  478. raise ValueError('Mo trees only works when cipher compress is enabled')
  479. if self.boosting_strategy not in [consts.STD_TREE, consts.LAYERED_TREE, consts.MIX_TREE]:
  480. raise ValueError('unknown sbt boosting strategy{}'.format(self.boosting_strategy))
  481. for p in ["early_stopping_rounds", "validation_freqs", "metrics",
  482. "use_first_metric_only"]:
  483. # if self._warn_to_deprecate_param(p, "", ""):
  484. if self._deprecated_params_set.get(p):
  485. if "callback_param" in self.get_user_feeded():
  486. raise ValueError(f"{p} and callback param should not be set simultaneously,"
  487. f"{self._deprecated_params_set}, {self.get_user_feeded()}")
  488. else:
  489. self.callback_param.callbacks = ["PerformanceEvaluate"]
  490. break
  491. descr = "boosting_param's"
  492. if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
  493. self.callback_param.validation_freqs = self.validation_freqs
  494. if self._warn_to_deprecate_param("early_stopping_rounds", descr, "callback_param's 'early_stopping_rounds'"):
  495. self.callback_param.early_stopping_rounds = self.early_stopping_rounds
  496. if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
  497. self.callback_param.metrics = self.metrics
  498. if self._warn_to_deprecate_param("use_first_metric_only", descr, "callback_param's 'use_first_metric_only'"):
  499. self.callback_param.use_first_metric_only = self.use_first_metric_only
  500. if self.top_rate + self.other_rate >= 1:
  501. raise ValueError('sum of top rate and other rate should be smaller than 1')
  502. return True
  503. class HomoSecureBoostParam(BoostingParam):
  504. """
  505. Parameters
  506. ----------
  507. backend: {'distributed', 'memory'}
  508. decides which backend to use when computing histograms for homo-sbt
  509. """
  510. def __init__(self, tree_param: DecisionTreeParam = DecisionTreeParam(), task_type=consts.CLASSIFICATION,
  511. objective_param=ObjectiveParam(),
  512. learning_rate=0.3, num_trees=5, subsample_feature_rate=1, n_iter_no_change=True,
  513. tol=0.0001, bin_num=32, predict_param=PredictParam(), cv_param=CrossValidationParam(),
  514. validation_freqs=None, use_missing=False, zero_as_missing=False, random_seed=100,
  515. binning_error=consts.DEFAULT_RELATIVE_ERROR, backend=consts.DISTRIBUTED_BACKEND,
  516. callback_param=CallbackParam(), multi_mode=consts.SINGLE_OUTPUT):
  517. super(HomoSecureBoostParam, self).__init__(task_type=task_type,
  518. objective_param=objective_param,
  519. learning_rate=learning_rate,
  520. num_trees=num_trees,
  521. subsample_feature_rate=subsample_feature_rate,
  522. n_iter_no_change=n_iter_no_change,
  523. tol=tol,
  524. bin_num=bin_num,
  525. predict_param=predict_param,
  526. cv_param=cv_param,
  527. validation_freqs=validation_freqs,
  528. random_seed=random_seed,
  529. binning_error=binning_error
  530. )
  531. self.use_missing = use_missing
  532. self.zero_as_missing = zero_as_missing
  533. self.tree_param = copy.deepcopy(tree_param)
  534. self.backend = backend
  535. self.callback_param = copy.deepcopy(callback_param)
  536. self.multi_mode = multi_mode
  537. def check(self):
  538. super(HomoSecureBoostParam, self).check()
  539. self.tree_param.check()
  540. if not isinstance(self.use_missing, bool):
  541. raise ValueError('use missing should be bool type')
  542. if not isinstance(self.zero_as_missing, bool):
  543. raise ValueError('zero as missing should be bool type')
  544. if self.backend not in [consts.MEMORY_BACKEND, consts.DISTRIBUTED_BACKEND]:
  545. raise ValueError('unsupported backend')
  546. if self.multi_mode not in [consts.SINGLE_OUTPUT, consts.MULTI_OUTPUT]:
  547. raise ValueError('unsupported multi-classification mode')
  548. for p in ["validation_freqs", "metrics"]:
  549. # if self._warn_to_deprecate_param(p, "", ""):
  550. if self._deprecated_params_set.get(p):
  551. if "callback_param" in self.get_user_feeded():
  552. raise ValueError(f"{p} and callback param should not be set simultaneously,"
  553. f"{self._deprecated_params_set}, {self.get_user_feeded()}")
  554. else:
  555. self.callback_param.callbacks = ["PerformanceEvaluate"]
  556. break
  557. descr = "boosting_param's"
  558. if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
  559. self.callback_param.validation_freqs = self.validation_freqs
  560. if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
  561. self.callback_param.metrics = self.metrics
  562. if self.multi_mode not in [consts.SINGLE_OUTPUT, consts.MULTI_OUTPUT]:
  563. raise ValueError('unsupported multi-classification mode')
  564. if self.multi_mode == consts.MULTI_OUTPUT:
  565. if self.task_type == consts.REGRESSION:
  566. raise ValueError('regression tasks not support multi-output trees')
  567. return True