callback_param.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from federatedml.param.base_param import BaseParam
  class CallbackParam(BaseParam):
      """
      Define the callback methods used in federated ML.

      Parameters
      ----------
      callbacks : list, default: []
          Which callback functions are desired during the training process.
          Accepted values: {'EarlyStopping', 'ModelCheckpoint', 'PerformanceEvaluate'}.
          ``None`` (or any other falsy value) is stored as an empty list.
      validation_freqs : {None, int, list, tuple, set}, default: None
          Validation frequency during training. An int must be greater than 0;
          a list/tuple/set is accepted as-is (its elements are not checked here).
          Must be set when ``early_stopping_rounds`` is enabled.
      early_stopping_rounds : None or int, default: None
          Stop training if a metric does not improve in the last
          ``early_stopping_rounds`` rounds. When set, it must be an int >= 1
          (``bool`` values are rejected).
      metrics : None or list, default: None
          Metrics used when evaluating during the training process. If empty,
          the default metrics for the specific task type are used; for binary
          classification the defaults are ['auc', 'ks'].
          ``None`` (or any other falsy value) is stored as an empty list.
      use_first_metric_only : bool, default: False
          Whether to use only the first metric for the early-stopping judgement.
      save_freq : int, default: 1
          The callbacks save the model every ``save_freq`` epochs.
          This value is not validated by ``check()``.
      """

      def __init__(self, callbacks=None, validation_freqs=None, early_stopping_rounds=None,
                   metrics=None, use_first_metric_only=False, save_freq=1):
          super(CallbackParam, self).__init__()
          # Falsy inputs become a fresh list per instance, so no mutable
          # default is ever shared between instances.
          self.callbacks = callbacks or []
          self.validation_freqs = validation_freqs
          self.early_stopping_rounds = early_stopping_rounds
          self.metrics = metrics or []
          self.use_first_metric_only = use_first_metric_only
          self.save_freq = save_freq

      def check(self):
          """
          Validate the parameters, normalizing some of them in place.

          ``callbacks`` and ``metrics`` that were reset to ``None`` after
          construction are normalized back to empty lists.

          Returns
          -------
          bool
              ``True`` when every check passes.

          Raises
          ------
          ValueError
              If any parameter has an unsupported type or value.
          """
          self.callbacks = [] if self.callbacks is None else self.callbacks
          self.metrics = [] if self.metrics is None else self.metrics

          if self.early_stopping_rounds is not None:
              # Previously any non-int value (e.g. "5" or 5.0) slipped through
              # silently. bool is a subclass of int, so exclude it explicitly:
              # True must not be read as "1 round".
              if isinstance(self.early_stopping_rounds, bool) or \
                      not isinstance(self.early_stopping_rounds, int):
                  raise ValueError("early stopping rounds should be None or an integer")
              if self.early_stopping_rounds < 1:
                  raise ValueError("early stopping rounds should be larger than 0 when it's integer")
              if self.validation_freqs is None:
                  raise ValueError("validation freqs must be set when early stopping is enabled")

          if self.validation_freqs is not None:
              # bool is excluded to keep the old behaviour: the former
              # type-name comparison ("bool" != "int") rejected it.
              if isinstance(self.validation_freqs, bool) or \
                      not isinstance(self.validation_freqs, (int, list, tuple, set)):
                  raise ValueError(
                      "validation strategy param's validate_freqs's type not supported ,"
                      " should be int or list or tuple or set"
                  )
              if isinstance(self.validation_freqs, int) and self.validation_freqs <= 0:
                  raise ValueError("validation strategy param's validate_freqs should greater than 0")

          # metrics is never None here (normalized above), so only the type matters.
          if not isinstance(self.metrics, list):
              raise ValueError("metrics should be a list")

          if not isinstance(self.use_first_metric_only, bool):
              raise ValueError("use_first_metric_only should be a boolean")

          return True