# base_feature_binning.py
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. import numpy as np
  19. from federatedml.feature.binning.base_binning import BaseBinning
  20. from federatedml.feature.binning.bin_inner_param import BinInnerParam
  21. from federatedml.feature.binning.bucket_binning import BucketBinning
  22. from federatedml.feature.binning.optimal_binning.optimal_binning import OptimalBinning
  23. from federatedml.feature.binning.quantile_binning import QuantileBinning
  24. from federatedml.feature.binning.iv_calculator import IvCalculator
  25. from federatedml.feature.binning.bin_result import MultiClassBinResult
  26. from federatedml.feature.fate_element_type import NoneType
  27. from federatedml.feature.sparse_vector import SparseVector
  28. from federatedml.model_base import ModelBase
  29. from federatedml.param.feature_binning_param import HeteroFeatureBinningParam as FeatureBinningParam
  30. from federatedml.protobuf.generated import feature_binning_meta_pb2, feature_binning_param_pb2
  31. from federatedml.statistic.data_overview import get_header, get_anonymous_header
  32. from federatedml.transfer_variable.transfer_class.hetero_feature_binning_transfer_variable import \
  33. HeteroFeatureBinningTransferVariable
  34. from federatedml.util import LOGGER
  35. from federatedml.util import abnormal_detection
  36. from federatedml.util import consts
  37. from federatedml.util.anonymous_generator_util import Anonymous
  38. from federatedml.util.io_check import assert_io_num_rows_equal
  39. from federatedml.util.schema_check import assert_schema_consistent
# Keys under which the binning param/meta protobufs are stored in the model dict.
MODEL_PARAM_NAME = 'FeatureBinningParam'
MODEL_META_NAME = 'FeatureBinningMeta'
class BaseFeatureBinning(ModelBase):
    """
    Do binning method through guest and host
    """

    def __init__(self):
        super(BaseFeatureBinning, self).__init__()
        # Transfer variable for syncing binning info between guest and host.
        self.transfer_variable = HeteroFeatureBinningTransferVariable()
        self.binning_obj: BaseBinning = None
        self.header = None
        self.anonymous_header = None
        # Anonymous header recorded at fit time; restored by load_model().
        self.training_anonymous_header = None
        self.schema = None
        # Per-host binning results collected during fit / transform.
        self.host_results = []
        self.transform_host_results = []
        self.transform_type = None
        self.model_param = FeatureBinningParam()
        self.bin_inner_param = BinInnerParam()
        # Defaults to binary labels; replaced once real labels are known.
        self.bin_result = MultiClassBinResult(labels=[0, 1])
        self.transform_bin_result = MultiClassBinResult(labels=[0, 1])
        self.has_missing_value = False
        self.labels = []
        self.use_manual_split_points = False
        self.has_woe_array = False
        # Either "fit" or "transform"; load_model() switches it to "transform".
        self._stage = "fit"
  66. def _init_model(self, params: FeatureBinningParam):
  67. self.model_param = params
  68. self.transform_type = self.model_param.transform_param.transform_type
  69. """
  70. if self.role == consts.HOST:
  71. if self.transform_type == "woe":
  72. raise ValueError("Host party do not support woe transform now.")
  73. """
  74. if self.model_param.method == consts.QUANTILE:
  75. self.binning_obj = QuantileBinning(self.model_param)
  76. elif self.model_param.method == consts.BUCKET:
  77. self.binning_obj = BucketBinning(self.model_param)
  78. elif self.model_param.method == consts.OPTIMAL:
  79. if self.role == consts.HOST:
  80. self.model_param.bin_num = self.model_param.optimal_binning_param.init_bin_nums
  81. self.binning_obj = QuantileBinning(self.model_param)
  82. else:
  83. self.binning_obj = OptimalBinning(self.model_param)
  84. else:
  85. raise ValueError(f"Binning method: {self.model_param.method} is not supported.")
  86. self.iv_calculator = IvCalculator(self.model_param.adjustment_factor,
  87. role=self.role,
  88. party_id=self.component_properties.local_partyid)
  89. def _get_manual_split_points(self, data_instances):
  90. data_index_to_col_name = dict(enumerate(data_instances.schema.get("header")))
  91. manual_split_points = {}
  92. if self.model_param.split_points_by_index is not None:
  93. manual_split_points = {
  94. data_index_to_col_name.get(int(k), None): v for k, v in self.model_param.split_points_by_index.items()
  95. }
  96. if None in manual_split_points.keys():
  97. raise ValueError(f"Index given in `split_points_by_index` not found in input data header."
  98. f"Please check.")
  99. if self.model_param.split_points_by_col_name is not None:
  100. for col_name, split_points in self.model_param.split_points_by_col_name.items():
  101. if manual_split_points.get(col_name) is not None:
  102. raise ValueError(f"Split points for feature {col_name} given in both "
  103. f"`split_points_by_index` and `split_points_by_col_name`. Please check.")
  104. manual_split_points[col_name] = split_points
  105. if set(self.bin_inner_param.bin_names) != set(manual_split_points.keys()):
  106. raise ValueError(f"Column set from provided split points dictionary does not match that of"
  107. f"`bin_names` or `bin_indexes. Please check.`")
  108. return manual_split_points
  109. @staticmethod
  110. def data_format_transform(row):
  111. """
  112. transform data into sparse format
  113. """
  114. if type(row.features).__name__ != consts.SPARSE_VECTOR:
  115. feature_shape = row.features.shape[0]
  116. indices = []
  117. data = []
  118. for i in range(feature_shape):
  119. if np.isnan(row.features[i]):
  120. indices.append(i)
  121. data.append(NoneType())
  122. elif np.abs(row.features[i]) < consts.FLOAT_ZERO:
  123. continue
  124. else:
  125. indices.append(i)
  126. data.append(row.features[i])
  127. new_row = copy.deepcopy(row)
  128. new_row.features = SparseVector(indices, data, feature_shape)
  129. return new_row
  130. else:
  131. sparse_vec = row.features.get_sparse_vector()
  132. replace_key = []
  133. for key in sparse_vec:
  134. if sparse_vec.get(key) == NoneType() or np.isnan(sparse_vec.get(key)):
  135. replace_key.append(key)
  136. if len(replace_key) == 0:
  137. return row
  138. else:
  139. new_row = copy.deepcopy(row)
  140. new_sparse_vec = new_row.features.get_sparse_vector()
  141. for key in replace_key:
  142. new_sparse_vec[key] = NoneType()
  143. return new_row
  144. def _setup_bin_inner_param(self, data_instances, params):
  145. if self.schema is not None:
  146. return
  147. self.header = get_header(data_instances)
  148. self.anonymous_header = get_anonymous_header(data_instances)
  149. LOGGER.debug("_setup_bin_inner_param, get header length: {}".format(len(self.header)))
  150. self.schema = data_instances.schema
  151. self.bin_inner_param.set_header(self.header, self.anonymous_header)
  152. if params.bin_indexes == -1:
  153. self.bin_inner_param.set_bin_all()
  154. else:
  155. self.bin_inner_param.add_bin_indexes(params.bin_indexes)
  156. self.bin_inner_param.add_bin_names(params.bin_names)
  157. self.bin_inner_param.add_category_indexes(params.category_indexes)
  158. self.bin_inner_param.add_category_names(params.category_names)
  159. if params.transform_param.transform_cols == -1:
  160. self.bin_inner_param.set_transform_all()
  161. else:
  162. self.bin_inner_param.add_transform_bin_indexes(params.transform_param.transform_cols)
  163. self.bin_inner_param.add_transform_bin_names(params.transform_param.transform_names)
  164. self.binning_obj.set_bin_inner_param(self.bin_inner_param)
  165. @assert_io_num_rows_equal
  166. @assert_schema_consistent
  167. def transform_data(self, data_instances):
  168. self._setup_bin_inner_param(data_instances, self.model_param)
  169. if self.transform_type != "woe":
  170. data_instances = self.binning_obj.transform(data_instances, self.transform_type)
  171. elif self.role == consts.HOST and not self.has_woe_array:
  172. raise ValueError("Woe transform is not available for host parties.")
  173. else:
  174. data_instances = self.iv_calculator.woe_transformer(data_instances, self.bin_inner_param,
  175. self.bin_result)
  176. self.set_schema(data_instances)
  177. self.data_output = data_instances
  178. return data_instances
  179. def _get_meta(self):
  180. # col_list = [str(x) for x in self.cols]
  181. transform_param = feature_binning_meta_pb2.TransformMeta(
  182. transform_cols=self.bin_inner_param.transform_bin_indexes,
  183. transform_type=self.model_param.transform_param.transform_type
  184. )
  185. optimal_metric_method = None
  186. if self.model_param.method == consts.OPTIMAL and not self.use_manual_split_points:
  187. optimal_metric_method = self.model_param.optimal_binning_param.metric_method
  188. meta_protobuf_obj = feature_binning_meta_pb2.FeatureBinningMeta(
  189. method=self.model_param.method,
  190. compress_thres=self.model_param.compress_thres,
  191. head_size=self.model_param.head_size,
  192. error=self.model_param.error,
  193. bin_num=self.model_param.bin_num,
  194. cols=self.bin_inner_param.bin_names,
  195. adjustment_factor=self.model_param.adjustment_factor,
  196. local_only=self.model_param.local_only,
  197. need_run=self.need_run,
  198. transform_param=transform_param,
  199. skip_static=self.model_param.skip_static,
  200. optimal_metric_method=optimal_metric_method
  201. )
  202. return meta_protobuf_obj
    def _get_param(self):
        """Build the FeatureBinningParam protobuf, including host results.

        At transform stage with an old-style anonymous header, guest first
        collects updated anonymous-header mappings from every host so the
        host results can be re-keyed before serialization.
        """
        split_points_result = self.binning_obj.bin_results.split_results
        multi_class_result = self.bin_result.generated_pb_list(split_points_result)
        host_multi_class_result = []
        host_single_results = []
        anonymous_dict_list = []
        # Anonymous-header sync is only needed when the loaded model predates
        # the current anonymous naming scheme (see _check_lower_version_anonymous).
        if self._stage == "transform" and self._check_lower_version_anonymous():
            if self.role == consts.GUEST:
                # One mapping dict per host party.
                anonymous_dict_list = self.transfer_variable.host_anonymous_header_dict.get(idx=-1)
            elif self.role == consts.HOST:
                # Map training-time anonymous names to the current ones and
                # send the mapping to guest.
                anonymous_dict = dict(zip(self.training_anonymous_header, self.anonymous_header))
                self.transfer_variable.host_anonymous_header_dict.remote(
                    anonymous_dict,
                    role=consts.GUEST,
                    idx=0
                )
        for idx, host_res in enumerate(self.host_results):
            if not anonymous_dict_list:
                host_multi_class_result.extend(host_res.generated_pb_list())
                host_single_results.append(host_res.bin_results[0].generated_pb())
            else:
                # Re-key this host's results with its refreshed anonymous
                # header before generating the protobufs.
                updated_anonymous_header = anonymous_dict_list[idx]
                host_res.update_anonymous(updated_anonymous_header)
                host_multi_class_result.extend(host_res.generated_pb_list())
                host_single_results.append(host_res.bin_results[0].generated_pb())
        has_host_result = True if len(host_multi_class_result) else False
        multi_pb = feature_binning_param_pb2.MultiClassResult(
            results=multi_class_result,
            labels=[str(x) for x in self.labels],
            host_results=host_multi_class_result,
            host_party_ids=[str(x) for x in self.component_properties.host_party_idlist],
            has_host_result=has_host_result
        )
        if self._stage == "fit":
            result_obj = feature_binning_param_pb2. \
                FeatureBinningParam(binning_result=multi_class_result[0],
                                    host_results=host_single_results,
                                    header=self.header,
                                    header_anonymous=self.anonymous_header,
                                    model_name=consts.BINNING_MODEL,
                                    multi_class_result=multi_pb)
        else:
            # Transform stage additionally serializes the transform-time
            # results alongside the fit-time ones.
            transform_multi_class_result = self.transform_bin_result.generated_pb_list(split_points_result)
            transform_host_single_results = []
            transform_host_multi_class_result = []
            for host_res in self.transform_host_results:
                transform_host_multi_class_result.extend(host_res.generated_pb_list())
                transform_host_single_results.append(host_res.bin_results[0].generated_pb())
            transform_multi_pb = feature_binning_param_pb2.MultiClassResult(
                results=transform_multi_class_result,
                labels=[str(x) for x in self.labels],
                host_results=transform_host_multi_class_result,
                host_party_ids=[str(x) for x in self.component_properties.host_party_idlist],
                has_host_result=has_host_result
            )
            result_obj = feature_binning_param_pb2. \
                FeatureBinningParam(binning_result=multi_class_result[0],
                                    host_results=host_single_results,
                                    header=self.header,
                                    header_anonymous=self.anonymous_header,
                                    model_name=consts.BINNING_MODEL,
                                    multi_class_result=multi_pb,
                                    transform_binning_result=transform_multi_class_result[0],
                                    transform_host_results=transform_host_single_results,
                                    transform_multi_class_result=transform_multi_pb)
        return result_obj
    def load_model(self, model_dict):
        """Restore binning state from a serialized model.

        Rebuilds bin results, the inner params, the binning object with its
        split points, and per-host results, then switches ``_stage`` to
        "transform".
        """
        model_param = list(model_dict.get('model').values())[0].get(MODEL_PARAM_NAME)
        model_meta = list(model_dict.get('model').values())[0].get(MODEL_META_NAME)
        self.bin_inner_param = BinInnerParam()

        multi_class_result = model_param.multi_class_result
        self.labels = list(map(int, multi_class_result.labels))
        if self.labels:
            self.bin_result = MultiClassBinResult.reconstruct(list(multi_class_result.results), self.labels)
            if self.role == consts.HOST:
                binning_result = dict(list(multi_class_result.results)[0].binning_result)
                woe_array = list(binning_result.values())[0].woe_array
                # A non-empty woe array on host means manual woe was provided;
                # reconstruct without labels and remember it for transform.
                if woe_array:
                    self.bin_result = MultiClassBinResult.reconstruct(list(multi_class_result.results))
                    self.has_woe_array = True
        assert isinstance(model_meta, feature_binning_meta_pb2.FeatureBinningMeta)
        assert isinstance(model_param, feature_binning_param_pb2.FeatureBinningParam)
        self.header = list(model_param.header)
        # Anonymous header as recorded at training time (may be old-format).
        self.training_anonymous_header = list(model_param.header_anonymous)
        self.bin_inner_param.set_header(self.header, self.training_anonymous_header)
        self.bin_inner_param.add_transform_bin_indexes(list(model_meta.transform_param.transform_cols))
        self.bin_inner_param.add_bin_names(list(model_meta.cols))
        self.transform_type = model_meta.transform_param.transform_type

        # Recreate the binning object matching the method used at fit time.
        bin_method = str(model_meta.method)
        if bin_method == consts.QUANTILE:
            self.binning_obj = QuantileBinning(params=model_meta)
        elif bin_method == consts.OPTIMAL:
            self.binning_obj = OptimalBinning(params=model_meta)
        else:
            self.binning_obj = BucketBinning(params=model_meta)
        self.binning_obj.set_bin_inner_param(self.bin_inner_param)

        # Restore the split points for every binned column.
        split_results = dict(model_param.binning_result.binning_result)
        for col_name, sr_pb in split_results.items():
            split_points = list(sr_pb.split_points)
            self.binning_obj.bin_results.put_col_split_points(col_name, split_points)

        # Rebuild host results: with binary labels each host pb stands alone;
        # otherwise the pbs come in groups of len(labels) per host party.
        self.host_results = []
        host_pbs = list(model_param.multi_class_result.host_results)
        if len(host_pbs):
            if len(self.labels) == 2:
                for host_pb in host_pbs:
                    self.host_results.append(MultiClassBinResult.reconstruct(
                        host_pb, self.labels))
            else:
                assert len(host_pbs) % len(self.labels) == 0
                i = 0
                while i < len(host_pbs):
                    this_pbs = host_pbs[i: i + len(self.labels)]
                    self.host_results.append(MultiClassBinResult.reconstruct(this_pbs, self.labels))
                    i += len(self.labels)
        """
        if list(model_param.header_anonymous):
            self.anonymous_header = list(model_param.anonymous_header)
        """
        self._stage = "transform"
  326. def export_model(self):
  327. if self.model_output is not None:
  328. return self.model_output
  329. meta_obj = self._get_meta()
  330. param_obj = self._get_param()
  331. result = {
  332. MODEL_META_NAME: meta_obj,
  333. MODEL_PARAM_NAME: param_obj
  334. }
  335. self.model_output = result
  336. return result
    def save_data(self):
        """Return the output data table recorded by transform_data()."""
        return self.data_output
    def set_schema(self, data_instance):
        """Write the recorded header back into *data_instance*'s schema."""
        self.schema['header'] = self.header
        data_instance.schema = self.schema
  343. def set_optimal_metric_array(self, optimal_metric_array_dict):
  344. # LOGGER.debug(f"optimal metric array dict: {optimal_metric_array_dict}")
  345. for col_name, optimal_metric_array in optimal_metric_array_dict.items():
  346. self.bin_result.put_optimal_metric_array(col_name, optimal_metric_array)
  347. # LOGGER.debug(f"after set optimal metric, self.bin_result metric is: {self.bin_result.all_optimal_metric}")
    def _abnormal_detection(self, data_instances):
        """
        Make sure input data_instances is valid: non-empty table, non-empty
        features, and a well-formed schema.
        """
        abnormal_detection.empty_table_detection(data_instances)
        abnormal_detection.empty_feature_detection(data_instances)
        self.check_schema_content(data_instances.schema)
  355. def _check_lower_version_anonymous(self):
  356. return not self.training_anonymous_header or \
  357. Anonymous.is_old_version_anonymous_header(self.training_anonymous_header)