#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy

from federatedml.param.base_param import BaseParam, deprecated_param
from federatedml.param.callback_param import CallbackParam
from federatedml.param.encrypt_param import EncryptParam
from federatedml.param.cross_validation_param import CrossValidationParam
from federatedml.param.init_model_param import InitParam
from federatedml.param.stepwise_param import StepwiseParam
from federatedml.util import consts


@deprecated_param("validation_freqs", "metrics", "early_stopping_rounds", "use_first_metric_only")
class LinearModelParam(BaseParam):
  28. """
  29. Parameters used for GLM.
  30. Parameters
  31. ----------
  32. penalty : {'L2' or 'L1'}
  33. Penalty method used in LinR. Please note that, when using encrypted version in HeteroLinR,
  34. 'L1' is not supported.
  35. tol : float, default: 1e-4
  36. The tolerance of convergence
  37. alpha : float, default: 1.0
  38. Regularization strength coefficient.
  39. optimizer : {'sgd', 'rmsprop', 'adam', 'sqn', 'adagrad', 'nesterov_momentum_sgd'}
  40. Optimize method
  41. batch_size : int, default: -1
  42. Batch size when updating model. -1 means use all data in a batch. i.e. Not to use mini-batch strategy.
  43. learning_rate : float, default: 0.01
  44. Learning rate
  45. max_iter : int, default: 20
  46. The maximum iteration for training.
  47. init_param: InitParam object, default: default InitParam object
  48. Init param method object.
  49. early_stop : {'diff', 'abs', 'weight_dff'}
  50. Method used to judge convergence.
  51. a) diff: Use difference of loss between two iterations to judge whether converge.
  52. b) abs: Use the absolute value of loss to judge whether converge. i.e. if loss < tol, it is converged.
  53. c) weight_diff: Use difference between weights of two consecutive iterations
  54. encrypt_param: EncryptParam object, default: default EncryptParam object
  55. encrypt param
  56. cv_param: CrossValidationParam object, default: default CrossValidationParam object
  57. cv param
  58. decay: int or float, default: 1
  59. Decay rate for learning rate. learning rate will follow the following decay schedule.
  60. lr = lr0/(1+decay*t) if decay_sqrt is False. If decay_sqrt is True, lr = lr0 / sqrt(1+decay*t)
  61. where t is the iter number.
  62. decay_sqrt: Bool, default: True
  63. lr = lr0/(1+decay*t) if decay_sqrt is False, otherwise, lr = lr0 / sqrt(1+decay*t)
  64. validation_freqs: int, list, tuple, set, or None
  65. validation frequency during training, required when using early stopping.
  66. The default value is None, 1 is suggested. You can set it to a number larger than 1 in order to speed up training by skipping validation rounds.
  67. When it is larger than 1, a number which is divisible by "max_iter" is recommended, otherwise, you will miss the validation scores of the last training iteration.
  68. early_stopping_rounds: int, default: None
  69. If positive number specified, at every specified training rounds, program checks for early stopping criteria.
  70. Validation_freqs must also be set when using early stopping.
  71. metrics: list or None, default: None
  72. Specify which metrics to be used when performing evaluation during training process. If metrics have not improved at early_stopping rounds, trianing stops before convergence.
  73. If set as empty, default metrics will be used. For regression tasks, default metrics are ['root_mean_squared_error', 'mean_absolute_error']
  74. use_first_metric_only: bool, default: False
  75. Indicate whether to use the first metric in `metrics` as the only criterion for early stopping judgement.
  76. floating_point_precision: None or integer
  77. if not None, use floating_point_precision-bit to speed up calculation,
  78. e.g.: convert an x to round(x * 2**floating_point_precision) during Paillier operation, divide
  79. the result by 2**floating_point_precision in the end.
  80. callback_param: CallbackParam object
  81. callback param
  82. """

    def __init__(self, penalty='L2',
                 tol=1e-4, alpha=1.0, optimizer='sgd',
                 batch_size=-1, learning_rate=0.01, init_param=InitParam(),
                 max_iter=100, early_stop='diff',
                 encrypt_param=EncryptParam(),
                 cv_param=CrossValidationParam(), decay=1, decay_sqrt=True, validation_freqs=None,
                 early_stopping_rounds=None, stepwise_param=StepwiseParam(), metrics=None, use_first_metric_only=False,
                 floating_point_precision=23, callback_param=CallbackParam()):
        super(LinearModelParam, self).__init__()
        self.penalty = penalty
        self.tol = tol
        self.alpha = alpha
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.init_param = copy.deepcopy(init_param)
        self.max_iter = max_iter
        self.early_stop = early_stop
        self.encrypt_param = encrypt_param
        self.cv_param = copy.deepcopy(cv_param)
        self.decay = decay
        self.decay_sqrt = decay_sqrt
        self.validation_freqs = validation_freqs
        self.early_stopping_rounds = early_stopping_rounds
        self.stepwise_param = copy.deepcopy(stepwise_param)
        self.metrics = metrics or []
        self.use_first_metric_only = use_first_metric_only
        self.floating_point_precision = floating_point_precision
        self.callback_param = copy.deepcopy(callback_param)

    def check(self):
        descr = "linear model param's "

        if self.penalty is None:
            self.penalty = 'NONE'
        elif type(self.penalty).__name__ != "str":
            raise ValueError(
                descr + "penalty {} not supported, should be str type".format(self.penalty))
        self.penalty = self.penalty.upper()
        if self.penalty not in [consts.L1_PENALTY, consts.L2_PENALTY, consts.NONE.upper()]:
            raise ValueError(
                "penalty {} not supported, penalty should be 'L1', 'L2' or 'NONE'".format(self.penalty))

        if type(self.tol).__name__ not in ["int", "float"]:
            raise ValueError(
                descr + "tol {} not supported, should be float type".format(self.tol))

        if type(self.alpha).__name__ not in ["int", "float"]:
            raise ValueError(
                descr + "alpha {} not supported, should be float type".format(self.alpha))

        if type(self.optimizer).__name__ != "str":
            raise ValueError(
                descr + "optimizer {} not supported, should be str type".format(self.optimizer))
        else:
            self.optimizer = self.optimizer.lower()
            if self.optimizer not in ['sgd', 'rmsprop', 'adam', 'adagrad', 'sqn', 'nesterov_momentum_sgd']:
                raise ValueError(
                    descr + "optimizer not supported, optimizer should be"
                            " 'sgd', 'rmsprop', 'adam', 'sqn', 'adagrad', or 'nesterov_momentum_sgd'")

        if type(self.batch_size).__name__ not in ["int", "long"]:
            raise ValueError(
                descr + "batch_size {} not supported, should be int type".format(self.batch_size))
        if self.batch_size != -1:
            if type(self.batch_size).__name__ not in ["int", "long"] \
                    or self.batch_size < consts.MIN_BATCH_SIZE:
                raise ValueError(descr + "batch_size {} not supported, should be no smaller than {}, "
                                         "or -1 to use all data in a batch".format(
                                             self.batch_size, consts.MIN_BATCH_SIZE))

        if type(self.learning_rate).__name__ not in ["int", "float"]:
            raise ValueError(
                descr + "learning_rate {} not supported, should be float type".format(
                    self.learning_rate))

        self.init_param.check()

        if type(self.max_iter).__name__ != "int":
            raise ValueError(
                descr + "max_iter {} not supported, should be int type".format(self.max_iter))
        elif self.max_iter <= 0:
            raise ValueError(
                descr + "max_iter must be greater than or equal to 1")

        if type(self.early_stop).__name__ != "str":
            raise ValueError(
                descr + "early_stop {} not supported, should be str type".format(
                    self.early_stop))
        else:
            self.early_stop = self.early_stop.lower()
            if self.early_stop not in ['diff', 'abs', 'weight_diff']:
                raise ValueError(
                    descr + "early_stop not supported, early_stop should be 'weight_diff', 'diff' or 'abs'")

        self.encrypt_param.check()

        if type(self.decay).__name__ not in ["int", "float"]:
            raise ValueError(
                descr + "decay {} not supported, should be 'int' or 'float'".format(self.decay)
            )
        if type(self.decay_sqrt).__name__ not in ["bool"]:
            raise ValueError(
                descr + "decay_sqrt {} not supported, should be 'bool'".format(self.decay_sqrt)
            )
        self.stepwise_param.check()

        # Migrate deprecated top-level early-stopping arguments into callback_param;
        # mixing the deprecated style with an explicit callback_param is rejected.
        for p in ["early_stopping_rounds", "validation_freqs", "metrics",
                  "use_first_metric_only"]:
            if self._warn_to_deprecate_param(p, "", ""):
                if "callback_param" in self.get_user_feeded():
                    raise ValueError(f"{p} and callback_param should not be set simultaneously")
                else:
                    self.callback_param.callbacks = ["PerformanceEvaluate"]
                break

        if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
            self.callback_param.validation_freqs = self.validation_freqs

        if self._warn_to_deprecate_param("early_stopping_rounds", descr, "callback_param's 'early_stopping_rounds'"):
            self.callback_param.early_stopping_rounds = self.early_stopping_rounds

        if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
            self.callback_param.metrics = self.metrics

        if self._warn_to_deprecate_param("use_first_metric_only", descr, "callback_param's 'use_first_metric_only'"):
            self.callback_param.use_first_metric_only = self.use_first_metric_only

        if self.floating_point_precision is not None and \
                (not isinstance(self.floating_point_precision, int) or
                 self.floating_point_precision < 0 or self.floating_point_precision > 64):
            raise ValueError("floating_point_precision should be None or an integer between 0 and 64")

        self.callback_param.check()
        return True
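
# A minimal usage sketch (illustrative addition, not part of the original FATE
# module): it builds a LinearModelParam with the recommended CallbackParam-based
# configuration instead of the deprecated top-level arguments, validates it with
# check(), then walks through the learning-rate decay schedule and the
# fixed-point encoding described in the docstring. The CallbackParam keyword
# names below are assumptions inferred from the fields that check() assigns.
if __name__ == "__main__":
    import math

    param = LinearModelParam(
        penalty='L2',
        optimizer='sgd',
        batch_size=-1,          # -1: use all data in each batch
        learning_rate=0.01,
        max_iter=100,
        early_stop='diff',
        callback_param=CallbackParam(
            callbacks=["PerformanceEvaluate"],
            validation_freqs=1,
            early_stopping_rounds=5,
            metrics=["root_mean_squared_error", "mean_absolute_error"],
            use_first_metric_only=False,
        ),
    )
    assert param.check()  # raises ValueError on any invalid field

    # Learning-rate decay schedule from the docstring, at iteration t:
    # lr = lr0 / sqrt(1 + decay * t) if decay_sqrt else lr0 / (1 + decay * t)
    t = 5
    lr0 = param.learning_rate
    lr_t = (lr0 / math.sqrt(1 + param.decay * t) if param.decay_sqrt
            else lr0 / (1 + param.decay * t))
    print(f"learning rate at iteration {t}: {lr_t:.6f}")

    # Fixed-point encoding implied by floating_point_precision: scale a float to
    # an integer before Paillier operations, then rescale the result afterwards.
    precision = param.floating_point_precision  # 23 by default
    x = 0.123456789
    encoded = round(x * 2 ** precision)   # integer used inside the encrypted domain
    decoded = encoded / 2 ** precision    # rescaled back after computation
    print(f"{x} -> {encoded} -> {decoded}")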