sample_param.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from pipeline.param.base_param import BaseParam
  19. import collections
  20. class SampleParam(BaseParam):
  21. """
  22. Define the sample method
  23. Parameters
  24. ----------
  25. mode: str, accepted 'random','stratified', 'exact_by_weight', specify sample to use, default: 'random'
  26. method: str, accepted 'downsample','upsample' only in this version. default: 'downsample'
  27. fractions: None or float or list, if mode equals to random, it should be a float number greater than 0,
  28. otherwise a list of elements of pairs like [label_i, sample_rate_i],
  29. e.g. [[0, 0.5], [1, 0.8], [2, 0.3]]. default: None
  30. random_state: int, RandomState instance or None, default: None
  31. need_run: bool, default True
  32. Indicate if this module needed to be run
  33. """
  34. def __init__(self, mode="random", method="downsample", fractions=None,
  35. random_state=None, task_type="hetero", need_run=True):
  36. self.mode = mode
  37. self.method = method
  38. self.fractions = fractions
  39. self.random_state = random_state
  40. self.task_type = task_type
  41. self.need_run = need_run
  42. def check(self):
  43. descr = "sample param"
  44. self.mode = self.check_and_change_lower(self.mode,
  45. ["random", "stratified", "exact_by_weight"],
  46. descr)
  47. self.method = self.check_and_change_lower(self.method,
  48. ["upsample", "downsample"],
  49. descr)
  50. if self.mode == "stratified" and self.fractions is not None:
  51. if not isinstance(self.fractions, list):
  52. raise ValueError("fractions of sample param when using stratified should be list")
  53. for ele in self.fractions:
  54. if not isinstance(ele, collections.Container) or len(ele) != 2:
  55. raise ValueError(
  56. "element in fractions of sample param using stratified should be a pair like [label_i, rate_i]")
  57. return True