bin_inner_param.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. from federatedml.util import LOGGER
  class BinInnerParam(object):
      """
      Store the column-related parameters used by a binning process.

      Three (index list, name list, membership set) triples are maintained:

      * ``bin_indexes`` / ``bin_names`` / ``bin_indexes_added_set``:
        columns to be binned;
      * ``transform_bin_indexes`` / ``transform_bin_names`` /
        ``transform_bin_indexes_added_set``: columns to be transformed;
      * ``category_indexes`` / ``category_names`` /
        ``category_indexes_added_set``: columns treated as categorical.

      Within each triple, ``names[i] == header[indexes[i]]`` and the set holds
      exactly the entries of the index list. Adding a column as a category
      removes it from the bin columns.
      """

      def __init__(self):
          self.bin_indexes = []
          self.bin_names = []
          self.bin_indexes_added_set = set()
          # real column name -> position in header
          self.col_name_maps = {}
          # anonymous column name -> real column name
          self.anonymous_col_name_maps = {}
          # real column name -> anonymous column name
          self.col_name_anonymous_maps = {}
          self.header = []
          self.anonymous_header = []
          self.transform_bin_indexes = []
          self.transform_bin_names = []
          self.transform_bin_indexes_added_set = set()
          self.category_indexes = []
          self.category_names = []
          self.category_indexes_added_set = set()

      def set_header(self, header, anonymous_header):
          """
          Record the header and rebuild every column-name lookup map.

          :param header: real column names, in column order.
          :param anonymous_header: anonymous column names, paired
              position-wise with ``header``.
          """
          self.header = copy.deepcopy(header)
          self.anonymous_header = copy.deepcopy(anonymous_header)
          # Rebuild (not update) so names from a previously set header do not
          # linger. With duplicate names the last position wins, as before.
          self.col_name_maps = {col_name: idx for idx, col_name in enumerate(self.header)}
          self.anonymous_col_name_maps = dict(zip(self.anonymous_header, self.header))
          self.col_name_anonymous_maps = dict(zip(self.header, self.anonymous_header))

      def set_bin_all(self):
          """Mark every header column as a bin column (user chose to bin all)."""
          self.bin_indexes = list(range(len(self.header)))
          self.bin_indexes_added_set = set(self.bin_indexes)
          self.bin_names = copy.deepcopy(self.header)

      def set_transform_all(self):
          """
          Mark every current bin column and category column for transformation.

          The transform lists are built as new lists. Aliasing them to the bin
          lists and extending in place would push category columns into
          ``bin_indexes``/``bin_names`` and couple later edits of either list.
          Duplicate indexes are skipped so repeated calls stay consistent with
          ``transform_bin_indexes_added_set``.
          """
          self.transform_bin_indexes = []
          self.transform_bin_names = []
          self.transform_bin_indexes_added_set = set()
          pairs = zip(self.bin_indexes + self.category_indexes,
                      self.bin_names + self.category_names)
          for idx, name in pairs:
              if idx not in self.transform_bin_indexes_added_set:
                  self.transform_bin_indexes.append(idx)
                  self.transform_bin_indexes_added_set.add(idx)
                  self.transform_bin_names.append(name)

      def add_bin_indexes(self, bin_indexes):
          """
          Add columns, given by header position, to the bin columns.

          :param bin_indexes: iterable of header positions, or None (no-op).
          :raises ValueError: if a position is outside ``[0, len(header))``.
          """
          if bin_indexes is None:
              return
          for idx in bin_indexes:
              # Negative positions are rejected as well: header[-1] would
              # otherwise silently resolve to the last column.
              if not self._is_valid_index(idx):
                  raise ValueError("Adding a index that out of header's bound")
              self._add_bin_index(idx)

      def add_bin_names(self, bin_names):
          """
          Add columns, given by name, to the bin columns.

          Names missing from the header are logged and skipped.

          :param bin_names: iterable of column names, or None (no-op).
          """
          if bin_names is None:
              return
          for bin_name in bin_names:
              idx = self.col_name_maps.get(bin_name)
              if idx is None:
                  LOGGER.warning("Adding a col_name that is not exist in header")
                  continue
              self._add_bin_index(idx)

      def add_transform_bin_indexes(self, transform_indexes):
          """
          Add columns, given by header position, to the transform columns.

          :param transform_indexes: iterable of header positions, or None.
          :raises ValueError: if a position is outside ``[0, len(header))``.
          """
          if transform_indexes is None:
              return
          for idx in transform_indexes:
              if not self._is_valid_index(idx):
                  raise ValueError("Adding a index that out of header's bound")
              self._add_transform_bin_index(idx)

      def add_transform_bin_names(self, transform_names):
          """
          Add columns, given by name, to the transform columns.

          :param transform_names: iterable of column names, or None (no-op).
          :raises ValueError: if a name is not in the header.
          """
          if transform_names is None:
              return
          for bin_name in transform_names:
              idx = self.col_name_maps.get(bin_name)
              if idx is None:
                  raise ValueError("Adding a col_name that is not exist in header")
              self._add_transform_bin_index(idx)

      def add_category_indexes(self, category_indexes):
          """
          Mark columns, given by header position, as categorical.

          Each such column is removed from the bin columns. Out-of-range
          positions (including negative ones) are logged and skipped.

          :param category_indexes: iterable of header positions, ``-1`` for
              every column, or None (no-op).
          """
          if category_indexes is None:
              return
          if category_indexes == -1:
              category_indexes = range(len(self.header))
          for idx in category_indexes:
              if not self._is_valid_index(idx):
                  LOGGER.warning("Adding a index that out of header's bound")
                  continue
              self._add_category_index(idx)
          self._align_bin_index()

      def add_category_names(self, category_names):
          """
          Mark columns, given by name, as categorical.

          Each such column is removed from the bin columns. Names missing from
          the header are logged and skipped.

          :param category_names: iterable of column names, or None (no-op).
          """
          if category_names is None:
              return
          for bin_name in category_names:
              idx = self.col_name_maps.get(bin_name)
              if idx is None:
                  LOGGER.warning("Adding a col_name that is not exist in header")
                  continue
              self._add_category_index(idx)
          self._align_bin_index()

      def _is_valid_index(self, idx):
          # True when idx addresses a real header column (no negative wrap).
          return 0 <= idx < len(self.header)

      def _add_bin_index(self, idx):
          # Append idx (and its header name) to the bin columns once.
          if idx not in self.bin_indexes_added_set:
              self.bin_indexes.append(idx)
              self.bin_indexes_added_set.add(idx)
              self.bin_names.append(self.header[idx])

      def _add_transform_bin_index(self, idx):
          # Append idx (and its header name) to the transform columns once.
          if idx not in self.transform_bin_indexes_added_set:
              self.transform_bin_indexes.append(idx)
              self.transform_bin_indexes_added_set.add(idx)
              self.transform_bin_names.append(self.header[idx])

      def _add_category_index(self, idx):
          # Record idx as a category column and drop it from the bin set; the
          # bin lists are re-synced afterwards by _align_bin_index().
          if idx not in self.category_indexes_added_set:
              self.category_indexes.append(idx)
              self.category_indexes_added_set.add(idx)
              self.category_names.append(self.header[idx])
          self.bin_indexes_added_set.discard(idx)

      def _align_bin_index(self):
          """Drop from bin_indexes/bin_names any index no longer in bin_indexes_added_set."""
          if len(self.bin_indexes_added_set) != len(self.bin_indexes):
              kept = [idx for idx in self.bin_indexes if idx in self.bin_indexes_added_set]
              self.bin_indexes = kept
              self.bin_names = [self.header[idx] for idx in kept]

      def get_need_cal_iv_cols_map(self):
          """Return ``{column name: header position}`` for bin and category columns."""
          names = self.bin_names + self.category_names
          indexs = self.bin_indexes + self.category_indexes
          assert len(names) == len(indexs)
          return dict(zip(names, indexs))

      @property
      def bin_cols_map(self):
          """``{column name: header position}`` for the bin columns only."""
          assert len(self.bin_indexes) == len(self.bin_names)
          return dict(zip(self.bin_names, self.bin_indexes))

      @staticmethod
      def change_to_anonymous(col_name, v, col_name_anonymous_maps: dict):
          """Return ``(anonymous name of col_name or None, v)``."""
          anonymous_col = col_name_anonymous_maps.get(col_name)
          return anonymous_col, v

      def get_anonymous_col_name_list(self, col_name_list: list):
          """
          Map real column names to anonymous ones, preserving order.

          :raises KeyError: if a name is not in the current header.
          """
          return [self.col_name_anonymous_maps[x] for x in col_name_list]

      def get_col_name_by_anonymous(self, anonymous_col_name: str):
          """Return the real column name for an anonymous name, or None if unknown."""
          return self.anonymous_col_name_maps.get(anonymous_col_name)