pearson_param.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from federatedml.param.base_param import BaseParam
  class PearsonParam(BaseParam):
      """
      Parameters for the pearson correlation component.

      Parameters
      ----------
      column_names : list of string, default: None
          list of column names; ``None`` is normalized to ``[]`` by ``check``

      column_indexes : list of int or -1, default: None
          list of column indexes; ``None`` is normalized to ``[]`` by ``check``.
          The only scalar value accepted is the int ``-1``
          (presumably "all columns" -- confirm against the component code).

      cross_parties : bool, default: True
          if True, calculate correlation of columns from both party

      need_run : bool, default: True
          set False to skip this party

      use_mix_rand : bool, default: False
          mix system random and pseudo random for quicker calculation

      calc_local_vif : bool, default: True
          calculate VIF for columns in local
      """

      def __init__(
          self,
          column_names=None,
          column_indexes=None,
          cross_parties=True,
          need_run=True,
          use_mix_rand=False,
          calc_local_vif=True,
      ):
          super().__init__()
          self.column_names = column_names
          self.column_indexes = column_indexes
          self.cross_parties = cross_parties
          self.need_run = need_run
          self.use_mix_rand = use_mix_rand
          self.calc_local_vif = calc_local_vif

      def check(self):
          """
          Validate the parameters and normalize unset column selectors.

          ``column_names`` and ``column_indexes`` are replaced in place by empty
          lists when they are ``None``.

          Raises
          ------
          ValueError
              - ``use_mix_rand`` is not a bool
              - ``cross_parties`` is truthy while ``need_run`` is falsy
              - ``column_names`` is not a list, or contains a non-str element
              - ``column_indexes`` is a list containing a non-int element,
                an int other than -1, or neither a list nor an int
              - ``need_run`` is truthy and both selectors are empty lists
          """
          if not isinstance(self.use_mix_rand, bool):
              raise ValueError(
                  f"use_mix_rand accept bool type only, {type(self.use_mix_rand)} got"
              )
          if self.cross_parties and (not self.need_run):
              raise ValueError(
                  "need_run should be True(which is default) when cross_parties is True."
              )

          # Normalize "not provided" to empty lists so the checks below can
          # assume a concrete container.
          self.column_indexes = [] if self.column_indexes is None else self.column_indexes
          self.column_names = [] if self.column_names is None else self.column_names

          if not isinstance(self.column_names, list):
              raise ValueError(
                  f"type mismatch, column_names with type {type(self.column_names)}"
              )
          for name in self.column_names:
              if not isinstance(name, str):
                  raise ValueError(
                      f"type mismatch, column_names with element {name}(type is {type(name)})"
                  )

          if isinstance(self.column_indexes, list):
              # NOTE(review): bool is a subclass of int, so True/False elements
              # pass this check -- confirm whether that is intended.
              for idx in self.column_indexes:
                  if not isinstance(idx, int):
                      raise ValueError(
                          f"type mismatch, column_indexes with element {idx}(type is {type(idx)})"
                      )
          elif isinstance(self.column_indexes, int):
              if self.column_indexes != -1:
                  raise ValueError(
                      f"column_indexes with type int and value {self.column_indexes}(only -1 allowed)"
                  )
          else:
              # Previously any other type (str, dict, tuple, ...) slipped through
              # unvalidated and also bypassed the emptiness check below.
              raise ValueError(
                  f"type mismatch, column_indexes with type {type(self.column_indexes)}"
              )

          # column_names is guaranteed to be a list here; only column_indexes
          # may still be the scalar -1, in which case no emptiness check applies.
          if self.need_run and isinstance(self.column_indexes, list):
              if not self.column_indexes and not self.column_names:
                  raise ValueError("provide at least one column")