data_split_param.py 4.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from federatedml.param.base_param import BaseParam
  19. from federatedml.util import LOGGER
  20. class DataSplitParam(BaseParam):
  21. """
  22. Define data split param that used in data split.
  23. Parameters
  24. ----------
  25. random_state : None or int, default: None
  26. Specify the random state for shuffle.
  27. test_size : float or int or None, default: 0.0
  28. Specify test data set size.
  29. float value specifies fraction of input data set, int value specifies exact number of data instances
  30. train_size : float or int or None, default: 0.8
  31. Specify train data set size.
  32. float value specifies fraction of input data set, int value specifies exact number of data instances
  33. validate_size : float or int or None, default: 0.2
  34. Specify validate data set size.
  35. float value specifies fraction of input data set, int value specifies exact number of data instances
  36. stratified : bool, default: False
  37. Define whether sampling should be stratified, according to label value.
  38. shuffle : bool, default: True
  39. Define whether do shuffle before splitting or not.
  40. split_points : None or list, default : None
  41. Specify the point(s) by which continuous label values are bucketed into bins for stratified split.
  42. eg.[0.2] for two bins or [0.1, 1, 3] for 4 bins
  43. need_run: bool, default: True
  44. Specify whether to run data split
  45. """
  46. def __init__(self, random_state=None, test_size=None, train_size=None, validate_size=None, stratified=False,
  47. shuffle=True, split_points=None, need_run=True):
  48. super(DataSplitParam, self).__init__()
  49. self.random_state = random_state
  50. self.test_size = test_size
  51. self.train_size = train_size
  52. self.validate_size = validate_size
  53. self.stratified = stratified
  54. self.shuffle = shuffle
  55. self.split_points = split_points
  56. self.need_run = need_run
  57. def check(self):
  58. model_param_descr = "data split param's "
  59. if self.random_state is not None:
  60. if not isinstance(self.random_state, int):
  61. raise ValueError(f"{model_param_descr} random state should be int type")
  62. BaseParam.check_nonnegative_number(self.random_state, f"{model_param_descr} random_state ")
  63. if self.test_size is not None:
  64. BaseParam.check_nonnegative_number(self.test_size, f"{model_param_descr} test_size ")
  65. if isinstance(self.test_size, float):
  66. BaseParam.check_decimal_float(self.test_size, f"{model_param_descr} test_size ")
  67. if self.train_size is not None:
  68. BaseParam.check_nonnegative_number(self.train_size, f"{model_param_descr} train_size ")
  69. if isinstance(self.train_size, float):
  70. BaseParam.check_decimal_float(self.train_size, f"{model_param_descr} train_size ")
  71. if self.validate_size is not None:
  72. BaseParam.check_nonnegative_number(self.validate_size, f"{model_param_descr} validate_size ")
  73. if isinstance(self.validate_size, float):
  74. BaseParam.check_decimal_float(self.validate_size, f"{model_param_descr} validate_size ")
  75. # use default size values if none given
  76. if self.test_size is None and self.train_size is None and self.validate_size is None:
  77. self.test_size = 0.0
  78. self.train_size = 0.8
  79. self.validate_size = 0.2
  80. BaseParam.check_boolean(self.stratified, f"{model_param_descr} stratified ")
  81. BaseParam.check_boolean(self.shuffle, f"{model_param_descr} shuffle ")
  82. BaseParam.check_boolean(self.need_run, f"{model_param_descr} need run ")
  83. if self.split_points is not None:
  84. if not isinstance(self.split_points, list):
  85. raise ValueError(f"{model_param_descr} split_points should be list type")
  86. LOGGER.debug("Finish data_split parameter check!")
  87. return True