sample_param.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from federatedml.param.base_param import BaseParam
  19. import collections
  20. class SampleParam(BaseParam):
  21. """
  22. Define the sample method
  23. Parameters
  24. ----------
  25. mode: {'random', 'stratified', 'exact_by_weight'}'
  26. specify sample to use, default: 'random'
  27. method: {'downsample', 'upsample'}, default: 'downsample'
  28. specify sample method
  29. fractions: None or float or list
  30. if mode equals to random, it should be a float number greater than 0,
  31. otherwise a list of elements of pairs like [label_i, sample_rate_i],
  32. e.g. [[0, 0.5], [1, 0.8], [2, 0.3]]. default: None
  33. random_state: int, RandomState instance or None, default: None
  34. random state
  35. need_run: bool, default True
  36. Indicate if this module needed to be run
  37. """
  38. def __init__(self, mode="random", method="downsample", fractions=None,
  39. random_state=None, task_type="hetero", need_run=True):
  40. self.mode = mode
  41. self.method = method
  42. self.fractions = fractions
  43. self.random_state = random_state
  44. self.task_type = task_type
  45. self.need_run = need_run
  46. def check(self):
  47. descr = "sample param"
  48. self.mode = self.check_and_change_lower(self.mode,
  49. ["random", "stratified", "exact_by_weight"],
  50. descr)
  51. self.method = self.check_and_change_lower(self.method,
  52. ["upsample", "downsample"],
  53. descr)
  54. if self.mode == "stratified" and self.fractions is not None:
  55. if not isinstance(self.fractions, list):
  56. raise ValueError("fractions of sample param when using stratified should be list")
  57. for ele in self.fractions:
  58. if not isinstance(ele, collections.Container) or len(ele) != 2:
  59. raise ValueError(
  60. "element in fractions of sample param using stratified should be a pair like [label_i, rate_i]")
  61. return True