data_split_param.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from pipeline.param.base_param import BaseParam
  19. class DataSplitParam(BaseParam):
  20. """
  21. Define data split param that used in data split.
  22. Parameters
  23. ----------
  24. random_state : None or int, default: None
  25. Specify the random state for shuffle.
  26. test_size : float or int or None, default: 0.0
  27. Specify test data set size.
  28. float value specifies fraction of input data set, int value specifies exact number of data instances
  29. train_size : float or int or None, default: 0.8
  30. Specify train data set size.
  31. float value specifies fraction of input data set, int value specifies exact number of data instances
  32. validate_size : float or int or None, default: 0.2
  33. Specify validate data set size.
  34. float value specifies fraction of input data set, int value specifies exact number of data instances
  35. stratified : bool, default: False
  36. Define whether sampling should be stratified, according to label value.
  37. shuffle : bool, default: True
  38. Define whether do shuffle before splitting or not.
  39. split_points : None or list, default : None
  40. Specify the point(s) by which continuous label values are bucketed into bins for stratified split.
  41. eg.[0.2] for two bins or [0.1, 1, 3] for 4 bins
  42. need_run: bool, default: True
  43. Specify whether to run data split
  44. """
  45. def __init__(self, random_state=None, test_size=None, train_size=None, validate_size=None, stratified=False,
  46. shuffle=True, split_points=None, need_run=True):
  47. super(DataSplitParam, self).__init__()
  48. self.random_state = random_state
  49. self.test_size = test_size
  50. self.train_size = train_size
  51. self.validate_size = validate_size
  52. self.stratified = stratified
  53. self.shuffle = shuffle
  54. self.split_points = split_points
  55. self.need_run = need_run
  56. def check(self):
  57. model_param_descr = "data split param's "
  58. if self.random_state is not None:
  59. if not isinstance(self.random_state, int):
  60. raise ValueError(f"{model_param_descr} random state should be int type")
  61. BaseParam.check_nonnegative_number(self.random_state, f"{model_param_descr} random_state ")
  62. if self.test_size is not None:
  63. BaseParam.check_nonnegative_number(self.test_size, f"{model_param_descr} test_size ")
  64. if isinstance(self.test_size, float):
  65. BaseParam.check_decimal_float(self.test_size, f"{model_param_descr} test_size ")
  66. if self.train_size is not None:
  67. BaseParam.check_nonnegative_number(self.train_size, f"{model_param_descr} train_size ")
  68. if isinstance(self.train_size, float):
  69. BaseParam.check_decimal_float(self.train_size, f"{model_param_descr} train_size ")
  70. if self.validate_size is not None:
  71. BaseParam.check_nonnegative_number(self.validate_size, f"{model_param_descr} validate_size ")
  72. if isinstance(self.validate_size, float):
  73. BaseParam.check_decimal_float(self.validate_size, f"{model_param_descr} validate_size ")
  74. # use default size values if none given
  75. if self.test_size is None and self.train_size is None and self.validate_size is None:
  76. self.test_size = 0.0
  77. self.train_size = 0.8
  78. self.validate_size = 0.2
  79. BaseParam.check_boolean(self.stratified, f"{model_param_descr} stratified ")
  80. BaseParam.check_boolean(self.shuffle, f"{model_param_descr} shuffle ")
  81. BaseParam.check_boolean(self.need_run, f"{model_param_descr} need run ")
  82. if self.split_points is not None:
  83. if not isinstance(self.split_points, list):
  84. raise ValueError(f"{model_param_descr} split_points should be list type")
  85. return True