# logistic_regression_param.py
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import copy
  19. from federatedml.param.base_param import deprecated_param
  20. from federatedml.param.glm_param import LinearModelParam
  21. from federatedml.param.callback_param import CallbackParam
  22. from federatedml.param.cross_validation_param import CrossValidationParam
  23. from federatedml.param.encrypt_param import EncryptParam
  24. from federatedml.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
  25. from federatedml.param.init_model_param import InitParam
  26. from federatedml.param.predict_param import PredictParam
  27. from federatedml.param.sqn_param import StochasticQuasiNewtonParam
  28. from federatedml.param.stepwise_param import StepwiseParam
  29. from federatedml.util import consts
  30. class LogisticParam(LinearModelParam):
  31. """
  32. Parameters used for Logistic Regression both for Homo mode or Hetero mode.
  33. Parameters
  34. ----------
  35. penalty : {'L2', 'L1' or None}
  36. Penalty method used in LR. Please note that, when using encrypted version in HomoLR,
  37. 'L1' is not supported.
  38. tol : float, default: 1e-4
  39. The tolerance of convergence
  40. alpha : float, default: 1.0
  41. Regularization strength coefficient.
  42. optimizer : {'rmsprop', 'sgd', 'adam', 'nesterov_momentum_sgd', 'adagrad'}, default: 'rmsprop'
  43. Optimize method.
  44. batch_strategy : str, {'full', 'random'}, default: "full"
  45. Strategy to generate batch data.
  46. a) full: use full data to generate batch_data, batch_nums every iteration is ceil(data_size / batch_size)
  47. b) random: select data randomly from full data, batch_num will be 1 every iteration.
  48. batch_size : int, default: -1
  49. Batch size when updating model. -1 means use all data in a batch. i.e. Not to use mini-batch strategy.
  50. shuffle : bool, default: True
  51. Work only in hetero logistic regression, batch data will be shuffle in every iteration.
  52. masked_rate: int, float: default: 5
  53. Use masked data to enhance security of hetero logistic regression
  54. learning_rate : float, default: 0.01
  55. Learning rate
  56. max_iter : int, default: 100
  57. The maximum iteration for training.
  58. early_stop : {'diff', 'weight_diff', 'abs'}, default: 'diff'
  59. Method used to judge converge or not.
  60. a) diff: Use difference of loss between two iterations to judge whether converge.
  61. b) weight_diff: Use difference between weights of two consecutive iterations
  62. c) abs: Use the absolute value of loss to judge whether converge. i.e. if loss < eps, it is converged.
  63. Please note that for hetero-lr multi-host situation, this parameter support "weight_diff" only.
  64. In homo-lr, weight_diff is not supported
  65. decay: int or float, default: 1
  66. Decay rate for learning rate. learning rate will follow the following decay schedule.
  67. lr = lr0/(1+decay*t) if decay_sqrt is False. If decay_sqrt is True, lr = lr0 / sqrt(1+decay*t)
  68. where t is the iter number.
  69. decay_sqrt: bool, default: True
  70. lr = lr0/(1+decay*t) if decay_sqrt is False, otherwise, lr = lr0 / sqrt(1+decay*t)
  71. encrypt_param: EncryptParam object, default: default EncryptParam object
  72. encrypt param
  73. predict_param: PredictParam object, default: default PredictParam object
  74. predict param
  75. callback_param: CallbackParam object
  76. callback param
  77. cv_param: CrossValidationParam object, default: default CrossValidationParam object
  78. cv param
  79. multi_class: {'ovr'}, default: 'ovr'
  80. If it is a multi_class task, indicate what strategy to use. Currently, support 'ovr' short for one_vs_rest only.
  81. validation_freqs: int or list or tuple or set, or None, default None
  82. validation frequency during training.
  83. early_stopping_rounds: int, default: None
  84. Will stop training if one metric doesn’t improve in last early_stopping_round rounds
  85. metrics: list or None, default: None
  86. Indicate when executing evaluation during train process, which metrics will be used. If set as empty,
  87. default metrics for specific task type will be used. As for binary classification, default metrics are
  88. ['auc', 'ks']
  89. use_first_metric_only: bool, default: False
  90. Indicate whether use the first metric only for early stopping judgement.
  91. floating_point_precision: None or integer
  92. if not None, use floating_point_precision-bit to speed up calculation,
  93. e.g.: convert an x to round(x * 2**floating_point_precision) during Paillier operation, divide
  94. the result by 2**floating_point_precision in the end.
  95. """
  96. def __init__(self, penalty='L2',
  97. tol=1e-4, alpha=1.0, optimizer='rmsprop',
  98. batch_size=-1, shuffle=True, batch_strategy="full", masked_rate=5,
  99. learning_rate=0.01, init_param=InitParam(),
  100. max_iter=100, early_stop='diff', encrypt_param=EncryptParam(),
  101. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  102. decay=1, decay_sqrt=True,
  103. multi_class='ovr', validation_freqs=None, early_stopping_rounds=None,
  104. stepwise_param=StepwiseParam(), floating_point_precision=23,
  105. metrics=None,
  106. use_first_metric_only=False,
  107. callback_param=CallbackParam()
  108. ):
  109. super(LogisticParam, self).__init__()
  110. self.penalty = penalty
  111. self.tol = tol
  112. self.alpha = alpha
  113. self.optimizer = optimizer
  114. self.batch_size = batch_size
  115. self.learning_rate = learning_rate
  116. self.init_param = copy.deepcopy(init_param)
  117. self.max_iter = max_iter
  118. self.early_stop = early_stop
  119. self.encrypt_param = encrypt_param
  120. self.shuffle = shuffle
  121. self.batch_strategy = batch_strategy
  122. self.masked_rate = masked_rate
  123. self.predict_param = copy.deepcopy(predict_param)
  124. self.cv_param = copy.deepcopy(cv_param)
  125. self.decay = decay
  126. self.decay_sqrt = decay_sqrt
  127. self.multi_class = multi_class
  128. self.validation_freqs = validation_freqs
  129. self.stepwise_param = copy.deepcopy(stepwise_param)
  130. self.early_stopping_rounds = early_stopping_rounds
  131. self.metrics = metrics or []
  132. self.use_first_metric_only = use_first_metric_only
  133. self.floating_point_precision = floating_point_precision
  134. self.callback_param = copy.deepcopy(callback_param)
  135. def check(self):
  136. descr = "logistic_param's"
  137. super(LogisticParam, self).check()
  138. self.predict_param.check()
  139. if self.encrypt_param.method not in [consts.PAILLIER, consts.PAILLIER_IPCL, None]:
  140. raise ValueError(
  141. "logistic_param's encrypted method support 'Paillier' or None only")
  142. self.multi_class = self.check_and_change_lower(
  143. self.multi_class, ["ovr"], f"{descr}")
  144. if not isinstance(self.masked_rate, (float, int)) or self.masked_rate < 0:
  145. raise ValueError(
  146. "masked rate should be non-negative numeric number")
  147. if not isinstance(self.batch_strategy, str) or self.batch_strategy.lower() not in ["full", "random"]:
  148. raise ValueError("batch strategy should be full or random")
  149. self.batch_strategy = self.batch_strategy.lower()
  150. if not isinstance(self.shuffle, bool):
  151. raise ValueError("shuffle should be boolean type")
  152. return True
  153. class HomoLogisticParam(LogisticParam):
  154. """
  155. Parameters
  156. ----------
  157. aggregate_iters : int, default: 1
  158. Indicate how many iterations are aggregated once.
  159. """
  160. def __init__(self, penalty='L2',
  161. tol=1e-4, alpha=1.0, optimizer='rmsprop',
  162. batch_size=-1, learning_rate=0.01, init_param=InitParam(),
  163. max_iter=100, early_stop='diff',
  164. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  165. decay=1, decay_sqrt=True,
  166. aggregate_iters=1, multi_class='ovr', validation_freqs=None,
  167. metrics=['auc', 'ks'],
  168. callback_param=CallbackParam()
  169. ):
  170. super(HomoLogisticParam, self).__init__(penalty=penalty, tol=tol, alpha=alpha, optimizer=optimizer,
  171. batch_size=batch_size,
  172. learning_rate=learning_rate,
  173. init_param=init_param, max_iter=max_iter, early_stop=early_stop,
  174. predict_param=predict_param,
  175. cv_param=cv_param, multi_class=multi_class,
  176. validation_freqs=validation_freqs,
  177. decay=decay, decay_sqrt=decay_sqrt,
  178. metrics=metrics,
  179. callback_param=callback_param)
  180. self.aggregate_iters = aggregate_iters
  181. def check(self):
  182. super().check()
  183. if not isinstance(self.aggregate_iters, int):
  184. raise ValueError(
  185. "logistic_param's aggregate_iters {} not supported, should be int type".format(
  186. self.aggregate_iters))
  187. return True
  188. class HeteroLogisticParam(LogisticParam):
  189. def __init__(self, penalty='L2',
  190. tol=1e-4, alpha=1.0, optimizer='rmsprop',
  191. batch_size=-1, shuffle=True, batch_strategy="full", masked_rate=5,
  192. learning_rate=0.01, init_param=InitParam(),
  193. max_iter=100, early_stop='diff',
  194. encrypted_mode_calculator_param=EncryptedModeCalculatorParam(),
  195. predict_param=PredictParam(), cv_param=CrossValidationParam(),
  196. decay=1, decay_sqrt=True, sqn_param=StochasticQuasiNewtonParam(),
  197. multi_class='ovr', validation_freqs=None, early_stopping_rounds=None,
  198. metrics=['auc', 'ks'], floating_point_precision=23,
  199. encrypt_param=EncryptParam(),
  200. use_first_metric_only=False, stepwise_param=StepwiseParam(),
  201. callback_param=CallbackParam()
  202. ):
  203. super(HeteroLogisticParam, self).__init__(penalty=penalty, tol=tol, alpha=alpha, optimizer=optimizer,
  204. batch_size=batch_size, shuffle=shuffle, batch_strategy=batch_strategy,
  205. masked_rate=masked_rate,
  206. learning_rate=learning_rate,
  207. init_param=init_param, max_iter=max_iter, early_stop=early_stop,
  208. predict_param=predict_param, cv_param=cv_param,
  209. decay=decay,
  210. decay_sqrt=decay_sqrt, multi_class=multi_class,
  211. validation_freqs=validation_freqs,
  212. early_stopping_rounds=early_stopping_rounds,
  213. metrics=metrics, floating_point_precision=floating_point_precision,
  214. encrypt_param=encrypt_param,
  215. use_first_metric_only=use_first_metric_only,
  216. stepwise_param=stepwise_param,
  217. callback_param=callback_param)
  218. self.encrypted_mode_calculator_param = copy.deepcopy(
  219. encrypted_mode_calculator_param)
  220. self.sqn_param = copy.deepcopy(sqn_param)
  221. def check(self):
  222. super().check()
  223. self.encrypted_mode_calculator_param.check()
  224. self.sqn_param.check()
  225. return True