pearson_param.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from pipeline.param.base_param import BaseParam
  19. class PearsonParam(BaseParam):
  20. def __init__(
  21. self,
  22. column_names=None,
  23. column_indexes=None,
  24. cross_parties=True,
  25. need_run=True,
  26. use_mix_rand=False,
  27. calc_local_vif=True,
  28. ):
  29. super().__init__()
  30. self.column_names = column_names
  31. self.column_indexes = column_indexes
  32. self.cross_parties = cross_parties
  33. self.need_run = need_run
  34. self.use_mix_rand = use_mix_rand
  35. if column_names is None:
  36. self.column_names = []
  37. if column_indexes is None:
  38. self.column_indexes = []
  39. self.calc_local_vif = calc_local_vif
  40. def check(self):
  41. if not isinstance(self.use_mix_rand, bool):
  42. raise ValueError(
  43. f"use_mix_rand accept bool type only, {type(self.use_mix_rand)} got"
  44. )
  45. if self.cross_parties and (not self.need_run):
  46. raise ValueError(
  47. f"need_run should be True(which is default) when cross_parties is True."
  48. )
  49. if not isinstance(self.column_names, list):
  50. raise ValueError(
  51. f"type mismatch, column_names with type {type(self.column_names)}"
  52. )
  53. for name in self.column_names:
  54. if not isinstance(name, str):
  55. raise ValueError(
  56. f"type mismatch, column_names with element {name}(type is {type(name)})"
  57. )
  58. if isinstance(self.column_indexes, list):
  59. for idx in self.column_indexes:
  60. if not isinstance(idx, int):
  61. raise ValueError(
  62. f"type mismatch, column_indexes with element {idx}(type is {type(idx)})"
  63. )
  64. if isinstance(self.column_indexes, int) and self.column_indexes != -1:
  65. raise ValueError(
  66. f"column_indexes with type int and value {self.column_indexes}(only -1 allowed)"
  67. )
  68. if self.need_run:
  69. if isinstance(self.column_indexes, list) and isinstance(
  70. self.column_names, list
  71. ):
  72. if len(self.column_indexes) == 0 and len(self.column_names) == 0:
  73. raise ValueError(f"provide at least one column")