#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pipeline.param import consts
from pipeline.param.base_param import BaseParam
# from pipeline.param.evaluation_param import EvaluateParam


  20. class CrossValidationParam(BaseParam):
  21. """
  22. Define cross validation params
  23. Parameters
  24. ----------
  25. n_splits: int, default: 5
  26. Specify how many splits used in KFold
  27. mode: str, default: 'Hetero'
  28. Indicate what mode is current task
  29. role: {'Guest', 'Host', 'Arbiter'}, default: 'Guest'
  30. Indicate what role is current party
  31. shuffle: bool, default: True
  32. Define whether do shuffle before KFold or not.
  33. random_seed: int, default: 1
  34. Specify the random seed for numpy shuffle
  35. need_cv: bool, default False
  36. Indicate if this module needed to be run
  37. output_fold_history: bool, default True
  38. Indicate whether to output table of ids used by each fold, else return original input data
  39. returned ids are formatted as: {original_id}#fold{fold_num}#{train/validate}
  40. history_value_type: {'score', 'instance'}, default score
  41. Indicate whether to include original instance or predict score in the output fold history,
  42. only effective when output_fold_history set to True
  43. """
  44. def __init__(self, n_splits=5, mode=consts.HETERO, role=consts.GUEST, shuffle=True, random_seed=1,
  45. need_cv=False, output_fold_history=True, history_value_type="score"):
  46. super(CrossValidationParam, self).__init__()
  47. self.n_splits = n_splits
  48. self.mode = mode
  49. self.role = role
  50. self.shuffle = shuffle
  51. self.random_seed = random_seed
  52. # self.evaluate_param = copy.deepcopy(evaluate_param)
  53. self.need_cv = need_cv
  54. self.output_fold_history = output_fold_history
  55. self.history_value_type = history_value_type
  56. def check(self):
  57. model_param_descr = "cross validation param's "
  58. self.check_positive_integer(self.n_splits, model_param_descr)
  59. self.check_valid_value(self.mode, model_param_descr, valid_values=[consts.HOMO, consts.HETERO])
  60. self.check_valid_value(self.role, model_param_descr, valid_values=[consts.HOST, consts.GUEST, consts.ARBITER])
  61. self.check_boolean(self.shuffle, model_param_descr)
  62. self.check_boolean(self.output_fold_history, model_param_descr)
  63. self.history_value_type = self.check_and_change_lower(
  64. self.history_value_type, ["instance", "score"], model_param_descr)
  65. if self.random_seed is not None:
  66. self.check_positive_integer(self.random_seed, model_param_descr)