logistic_regression_param.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import copy
  19. from pipeline.param.glm_param import LinearModelParam
  20. from pipeline.param.callback_param import CallbackParam
  21. from pipeline.param.cross_validation_param import CrossValidationParam
  22. from pipeline.param.encrypt_param import EncryptParam
  23. from pipeline.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
  24. from pipeline.param.init_model_param import InitParam
  25. from pipeline.param.predict_param import PredictParam
  26. from pipeline.param.sqn_param import StochasticQuasiNewtonParam
  27. from pipeline.param.stepwise_param import StepwiseParam
  28. from pipeline.param import consts
  29. class LogisticParam(LinearModelParam):
  30. """
  31. Parameters used for Logistic Regression both for Homo mode or Hetero mode.
  32. Parameters
  33. ----------
  34. penalty : {'L2', 'L1' or None}
  35. Penalty method used in LR. Please note that, when using encrypted version in HomoLR,
  36. 'L1' is not supported.
  37. tol : float, default: 1e-4
  38. The tolerance of convergence
  39. alpha : float, default: 1.0
  40. Regularization strength coefficient.
  41. optimizer : {'rmsprop', 'sgd', 'adam', 'nesterov_momentum_sgd', 'sqn', 'adagrad'}, default: 'rmsprop'
  42. Optimize method, if 'sqn' has been set, sqn_param will take effect. Currently, 'sqn' support hetero mode only.
  43. batch_size : int, default: -1
  44. Batch size when updating model. -1 means use all data in a batch. i.e. Not to use mini-batch strategy.
  45. learning_rate : float, default: 0.01
  46. Learning rate
  47. max_iter : int, default: 100
  48. The maximum iteration for training.
  49. early_stop : {'diff', 'weight_diff', 'abs'}, default: 'diff'
  50. Method used to judge converge or not.
  51. a) diff: Use difference of loss between two iterations to judge whether converge.
  52. b) weight_diff: Use difference between weights of two consecutive iterations
  53. c) abs: Use the absolute value of loss to judge whether converge. i.e. if loss < eps, it is converged.
  54. Please note that for hetero-lr multi-host situation, this parameter support "weight_diff" only.
  55. decay: int or float, default: 1
  56. Decay rate for learning rate. learning rate will follow the following decay schedule.
  57. lr = lr0/(1+decay*t) if decay_sqrt is False. If decay_sqrt is True, lr = lr0 / sqrt(1+decay*t)
  58. where t is the iter number.
  59. decay_sqrt: bool, default: True
  60. lr = lr0/(1+decay*t) if decay_sqrt is False, otherwise, lr = lr0 / sqrt(1+decay*t)
  61. encrypt_param: EncryptParam object, default: default EncryptParam object
  62. encrypt param
  63. predict_param: PredictParam object, default: default PredictParam object
  64. predict param
  65. callback_param: CallbackParam object
  66. callback param
  67. cv_param: CrossValidationParam object, default: default CrossValidationParam object
  68. cv param
  69. multi_class: {'ovr'}, default: 'ovr'
  70. If it is a multi_class task, indicate what strategy to use. Currently, support 'ovr' short for one_vs_rest only.
  71. validation_freqs: int or list or tuple or set, or None, default None
  72. validation frequency during training.
  73. early_stopping_rounds: int, default: None
  74. Will stop training if one metric doesn’t improve in last early_stopping_round rounds
  75. metrics: list or None, default: None
  76. Indicate when executing evaluation during train process, which metrics will be used. If set as empty,
  77. default metrics for specific task type will be used. As for binary classification, default metrics are
  78. ['auc', 'ks']
  79. use_first_metric_only: bool, default: False
  80. Indicate whether use the first metric only for early stopping judgement.
  81. floating_point_precision: None or integer
  82. if not None, use floating_point_precision-bit to speed up calculation,
  83. e.g.: convert an x to round(x * 2**floating_point_precision) during Paillier operation, divide
  84. the result by 2**floating_point_precision in the end.
  85. """
  86. def __init__(self, penalty='L2',
  87. tol=1e-4, alpha=1.0, optimizer='rmsprop',
  88. batch_size=-1, shuffle=True, batch_strategy="full", masked_rate=5,
  89. learning_rate=0.01, init_param=InitParam(),
  90. max_iter=100, early_stop='diff', encrypt_param=EncryptParam(),
  91. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  92. decay=1, decay_sqrt=True,
  93. multi_class='ovr', validation_freqs=None, early_stopping_rounds=None,
  94. stepwise_param=StepwiseParam(), floating_point_precision=23,
  95. metrics=None,
  96. use_first_metric_only=False,
  97. callback_param=CallbackParam()
  98. ):
  99. super(LogisticParam, self).__init__()
  100. self.penalty = penalty
  101. self.tol = tol
  102. self.alpha = alpha
  103. self.optimizer = optimizer
  104. self.batch_size = batch_size
  105. self.learning_rate = learning_rate
  106. self.init_param = copy.deepcopy(init_param)
  107. self.max_iter = max_iter
  108. self.early_stop = early_stop
  109. self.encrypt_param = encrypt_param
  110. self.shuffle = shuffle
  111. self.batch_strategy = batch_strategy
  112. self.masked_rate = masked_rate
  113. self.predict_param = copy.deepcopy(predict_param)
  114. self.cv_param = copy.deepcopy(cv_param)
  115. self.decay = decay
  116. self.decay_sqrt = decay_sqrt
  117. self.multi_class = multi_class
  118. self.validation_freqs = validation_freqs
  119. self.stepwise_param = copy.deepcopy(stepwise_param)
  120. self.early_stopping_rounds = early_stopping_rounds
  121. self.metrics = metrics or []
  122. self.use_first_metric_only = use_first_metric_only
  123. self.floating_point_precision = floating_point_precision
  124. self.callback_param = copy.deepcopy(callback_param)
  125. def check(self):
  126. descr = "logistic_param's"
  127. super(LogisticParam, self).check()
  128. self.predict_param.check()
  129. if self.encrypt_param.method not in [consts.PAILLIER, None]:
  130. raise ValueError(
  131. "logistic_param's encrypted method support 'Paillier' or None only")
  132. self.multi_class = self.check_and_change_lower(self.multi_class, ["ovr"], f"{descr}")
  133. return True
  134. class HomoLogisticParam(LogisticParam):
  135. """
  136. Parameters
  137. ----------
  138. aggregate_iters : int, default: 1
  139. Indicate how many iterations are aggregated once.
  140. """
  141. def __init__(self, penalty='L2',
  142. tol=1e-4, alpha=1.0, optimizer='rmsprop',
  143. batch_size=-1, learning_rate=0.01, init_param=InitParam(),
  144. max_iter=100, early_stop='diff',
  145. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  146. decay=1, decay_sqrt=True,
  147. aggregate_iters=1, multi_class='ovr', validation_freqs=None,
  148. metrics=['auc', 'ks'],
  149. callback_param=CallbackParam()
  150. ):
  151. super(HomoLogisticParam, self).__init__(penalty=penalty, tol=tol, alpha=alpha, optimizer=optimizer,
  152. batch_size=batch_size,
  153. learning_rate=learning_rate,
  154. init_param=init_param, max_iter=max_iter, early_stop=early_stop,
  155. predict_param=predict_param,
  156. cv_param=cv_param, multi_class=multi_class,
  157. validation_freqs=validation_freqs,
  158. decay=decay, decay_sqrt=decay_sqrt,
  159. metrics=metrics,
  160. callback_param=callback_param)
  161. self.aggregate_iters = aggregate_iters
  162. def check(self):
  163. super().check()
  164. if not isinstance(self.aggregate_iters, int):
  165. raise ValueError(
  166. "logistic_param's aggregate_iters {} not supported, should be int type".format(
  167. self.aggregate_iters))
  168. return True
  169. class HeteroLogisticParam(LogisticParam):
  170. def __init__(self, penalty='L2',
  171. tol=1e-4, alpha=1.0, optimizer='rmsprop',
  172. batch_size=-1, shuffle=True, batch_strategy="full", masked_rate=5,
  173. learning_rate=0.01, init_param=InitParam(),
  174. max_iter=100, early_stop='diff',
  175. encrypted_mode_calculator_param=EncryptedModeCalculatorParam(),
  176. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  177. decay=1, decay_sqrt=True, sqn_param=StochasticQuasiNewtonParam(),
  178. multi_class='ovr', validation_freqs=None, early_stopping_rounds=None,
  179. metrics=['auc', 'ks'], floating_point_precision=23,
  180. encrypt_param=EncryptParam(),
  181. use_first_metric_only=False, stepwise_param=StepwiseParam(),
  182. callback_param=CallbackParam()
  183. ):
  184. super(
  185. HeteroLogisticParam,
  186. self).__init__(
  187. penalty=penalty,
  188. tol=tol,
  189. alpha=alpha,
  190. optimizer=optimizer,
  191. batch_size=batch_size,
  192. shuffle=shuffle,
  193. batch_strategy=batch_strategy,
  194. masked_rate=masked_rate,
  195. learning_rate=learning_rate,
  196. init_param=init_param,
  197. max_iter=max_iter,
  198. early_stop=early_stop,
  199. predict_param=predict_param,
  200. cv_param=cv_param,
  201. decay=decay,
  202. decay_sqrt=decay_sqrt,
  203. multi_class=multi_class,
  204. validation_freqs=validation_freqs,
  205. early_stopping_rounds=early_stopping_rounds,
  206. metrics=metrics,
  207. floating_point_precision=floating_point_precision,
  208. encrypt_param=encrypt_param,
  209. use_first_metric_only=use_first_metric_only,
  210. stepwise_param=stepwise_param,
  211. callback_param=callback_param)
  212. self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
  213. self.sqn_param = copy.deepcopy(sqn_param)
  214. def check(self):
  215. super().check()
  216. self.encrypted_mode_calculator_param.check()
  217. self.sqn_param.check()
  218. return True