poisson_regression_param.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import copy
  19. from pipeline.param.glm_param import LinearModelParam
  20. from pipeline.param.callback_param import CallbackParam
  21. from pipeline.param.encrypt_param import EncryptParam
  22. from pipeline.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
  23. from pipeline.param.cross_validation_param import CrossValidationParam
  24. from pipeline.param.init_model_param import InitParam
  25. from pipeline.param.stepwise_param import StepwiseParam
  26. from pipeline.param import consts
  27. class PoissonParam(LinearModelParam):
  28. """
  29. Parameters used for Poisson Regression.
  30. Parameters
  31. ----------
  32. penalty : {'L2', 'L1'}, default: 'L2'
  33. Penalty method used in Poisson. Please note that, when using encrypted version in HeteroPoisson,
  34. 'L1' is not supported.
  35. tol : float, default: 1e-4
  36. The tolerance of convergence
  37. alpha : float, default: 1.0
  38. Regularization strength coefficient.
  39. optimizer : {'rmsprop', 'sgd', 'adam', 'adagrad'}, default: 'rmsprop'
  40. Optimize method
  41. batch_size : int, default: -1
  42. Batch size when updating model. -1 means use all data in a batch. i.e. Not to use mini-batch strategy.
  43. learning_rate : float, default: 0.01
  44. Learning rate
  45. max_iter : int, default: 20
  46. The maximum iteration for training.
  47. init_param: InitParam object, default: default InitParam object
  48. Init param method object.
  49. early_stop : str, 'weight_diff', 'diff' or 'abs', default: 'diff'
  50. Method used to judge convergence.
  51. a) diff: Use difference of loss between two iterations to judge whether converge.
  52. b) weight_diff: Use difference between weights of two consecutive iterations
  53. c) abs: Use the absolute value of loss to judge whether converge. i.e. if loss < eps, it is converged.
  54. exposure_colname: str or None, default: None
  55. Name of optional exposure variable in dTable.
  56. encrypt_param: EncryptParam object, default: default EncryptParam object
  57. encrypt param
  58. encrypted_mode_calculator_param: EncryptedModeCalculatorParam object, default: default EncryptedModeCalculatorParam object
  59. encrypted mode calculator param
  60. cv_param: CrossValidationParam object, default: default CrossValidationParam object
  61. cv param
  62. stepwise_param: StepwiseParam object, default: default StepwiseParam object
  63. stepwise param
  64. decay: int or float, default: 1
  65. Decay rate for learning rate. learning rate will follow the following decay schedule.
  66. lr = lr0/(1+decay*t) if decay_sqrt is False. If decay_sqrt is True, lr = lr0 / sqrt(1+decay*t)
  67. where t is the iter number.
  68. decay_sqrt: bool, default: True
  69. lr = lr0/(1+decay*t) if decay_sqrt is False, otherwise, lr = lr0 / sqrt(1+decay*t)
  70. validation_freqs: int, list, tuple, set, or None
  71. validation frequency during training, required when using early stopping.
  72. The default value is None, 1 is suggested. You can set it to a number larger than 1 in order to speed up training by skipping validation rounds.
  73. When it is larger than 1, a number which is divisible by "max_iter" is recommended, otherwise, you will miss the validation scores of the last training iteration.
  74. early_stopping_rounds: int, default: None
  75. If positive number specified, at every specified training rounds, program checks for early stopping criteria.
  76. Validation_freqs must also be set when using early stopping.
  77. metrics: list or None, default: None
  78. Specify which metrics to be used when performing evaluation during training process. If metrics have not improved at early_stopping rounds, trianing stops before convergence.
  79. If set as empty, default metrics will be used. For regression tasks, default metrics are ['root_mean_squared_error', 'mean_absolute_error']
  80. use_first_metric_only: bool, default: False
  81. Indicate whether to use the first metric in `metrics` as the only criterion for early stopping judgement.
  82. floating_point_precision: None or integer
  83. if not None, use floating_point_precision-bit to speed up calculation,
  84. e.g.: convert an x to round(x * 2**floating_point_precision) during Paillier operation, divide
  85. the result by 2**floating_point_precision in the end.
  86. callback_param: CallbackParam object
  87. callback param
  88. """
  89. def __init__(self, penalty='L2',
  90. tol=1e-4, alpha=1.0, optimizer='rmsprop',
  91. batch_size=-1, learning_rate=0.01, init_param=InitParam(),
  92. max_iter=20, early_stop='diff',
  93. exposure_colname=None,
  94. encrypt_param=EncryptParam(),
  95. encrypted_mode_calculator_param=EncryptedModeCalculatorParam(),
  96. cv_param=CrossValidationParam(), stepwise_param=StepwiseParam(),
  97. decay=1, decay_sqrt=True,
  98. validation_freqs=None, early_stopping_rounds=None, metrics=None, use_first_metric_only=False,
  99. floating_point_precision=23, callback_param=CallbackParam()):
  100. super(PoissonParam, self).__init__(penalty=penalty, tol=tol, alpha=alpha, optimizer=optimizer,
  101. batch_size=batch_size, learning_rate=learning_rate,
  102. init_param=init_param, max_iter=max_iter,
  103. early_stop=early_stop, cv_param=cv_param, decay=decay,
  104. decay_sqrt=decay_sqrt, validation_freqs=validation_freqs,
  105. early_stopping_rounds=early_stopping_rounds, metrics=metrics,
  106. floating_point_precision=floating_point_precision,
  107. encrypt_param=encrypt_param,
  108. use_first_metric_only=use_first_metric_only,
  109. stepwise_param=stepwise_param,
  110. callback_param=callback_param)
  111. self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
  112. self.exposure_colname = exposure_colname
  113. def check(self):
  114. descr = "poisson_regression_param's "
  115. super(PoissonParam, self).check()
  116. if self.encrypt_param.method != consts.PAILLIER:
  117. raise ValueError(
  118. descr + "encrypt method supports 'Paillier' only")
  119. if self.optimizer not in ['sgd', 'rmsprop', 'adam', 'adagrad']:
  120. raise ValueError(
  121. descr + "optimizer not supported, optimizer should be"
  122. " 'sgd', 'rmsprop', 'adam', or 'adagrad'")
  123. if self.exposure_colname is not None:
  124. if type(self.exposure_colname).__name__ != "str":
  125. raise ValueError(
  126. descr + "exposure_colname {} not supported, should be string type".format(self.exposure_colname))
  127. self.encrypted_mode_calculator_param.check()
  128. return True