#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
- import functools
- import operator
- from federatedml.cipher_compressor.compressor import CipherCompressorHost
- from federatedml.feature.hetero_feature_binning.base_feature_binning import BaseFeatureBinning
- from federatedml.secureprotol.fate_paillier import PaillierEncryptedNumber
- from federatedml.util import LOGGER
- from federatedml.util import consts
class HeteroFeatureBinningHost(BaseFeatureBinning):
    """Host-side party of hetero (vertical) feature binning.

    The host fits split points for its own features locally. Unless running
    in local-only mode, it then aggregates guest-encrypted labels per bin,
    compresses the ciphertexts, anonymizes its column names, and sends the
    result to the guest, which computes the label-dependent metrics.
    """

    def __init__(self):
        super(HeteroFeatureBinningHost, self).__init__()
        # Ciphertext compressor for transfer; created lazily in
        # stat_and_transform only when the guest side actually has labels.
        self.compressor = None

    def fit(self, data_instances):
        """Fit split points on local data, then run the federated
        statistic/transform step.

        Parameters
        ----------
        data_instances : table
            Local feature data of this host party.

        Returns
        -------
        The (possibly transformed) data instances, as produced by
        ``stat_and_transform``.

        Raises
        ------
        ValueError
            If optimal binning is requested and any selected column
            contains missing values.
        """
        self._abnormal_detection(data_instances)
        self._setup_bin_inner_param(data_instances, self.model_param)

        # Optimal binning cannot handle missing values in the columns
        # selected for binning; fail fast before any heavy computation.
        if self.model_param.method == consts.OPTIMAL:
            has_missing_value = self.iv_calculator.check_containing_missing_value(data_instances)
            for idx in self.bin_inner_param.bin_indexes:
                if idx in has_missing_value:
                    raise ValueError("Optimal Binning do not support missing value now.")

        if self.model_param.split_points_by_col_name or self.model_param.split_points_by_index:
            # Split points were supplied manually; record them instead of fitting.
            split_points = self._get_manual_split_points(data_instances)
            self.use_manual_split_points = True
            for col_name, sp in split_points.items():
                self.binning_obj.bin_results.put_col_split_points(col_name, sp)
        else:
            # Calculates split points of data in self (host) part.
            split_points = self.binning_obj.fit_split_points(data_instances)

        return self.stat_and_transform(data_instances, split_points)

    def transform(self, data_instances):
        """Apply the previously fitted split points to new data."""
        self._setup_bin_inner_param(data_instances, self.model_param)
        split_points = self.binning_obj.bin_results.all_split_points
        return self.stat_and_transform(data_instances, split_points)

    def stat_and_transform(self, data_instances, split_points):
        """
        Apply binning method for both data instances in local party as well as the other one. Afterwards, calculate
        the specific metric value for specific columns.
        """
        if self.model_param.skip_static:
            # Statistics are skipped entirely: only transform locally and return.
            data_instances = self.transform_data(data_instances)
            self.set_schema(data_instances)
            self.data_output = data_instances
            return data_instances

        if not self.model_param.local_only:
            # During transform the guest tells us whether labels are present;
            # during fit, labels are assumed to exist.
            has_label = True
            if self._stage == "transform":
                has_label = self.transfer_variable.transform_stage_has_label.get(idx=0)
            if has_label:
                self.compressor = CipherCompressorHost()
                self._sync_init_bucket(data_instances, split_points)
                if self.model_param.method == consts.OPTIMAL and self._stage == "fit":
                    # The guest decides the optimal bucket merging; fetch it.
                    self.optimal_binning_sync()

        data_instances = self.transform_data(data_instances)
        self.set_schema(data_instances)
        self.data_output = data_instances
        total_summary = self.binning_obj.bin_results.to_json()
        self.set_summary(total_summary)
        return data_instances

    def _sync_init_bucket(self, data_instances, split_points, need_shuffle=False):
        """Compute per-bin encrypted label sums and send them, together with
        binning meta information, to the guest.

        Parameters
        ----------
        data_instances : table
            Local feature data.
        split_points : dict
            col_name -> split point list, as fitted or manually supplied.
        need_shuffle : bool
            Currently unused; kept for interface compatibility.
        """
        data_bin_table = self.binning_obj.get_data_bin(data_instances, split_points,
                                                       self.bin_inner_param.bin_cols_map)
        encrypted_label_table = self.transfer_variable.encrypted_label.get(idx=0)
        LOGGER.info("Get encrypted_label_table from guest")

        encrypted_bin_sum = self.__static_encrypted_bin_label(data_bin_table, encrypted_label_table)
        encrypted_bin_sum = self.compressor.compress_dtable(encrypted_bin_sum)

        # Replace real host column names with anonymous ones so the guest
        # never sees the host's feature names.
        encode_name_f = functools.partial(self.bin_inner_param.change_to_anonymous,
                                          col_name_anonymous_maps=self.bin_inner_param.col_name_anonymous_maps)
        encrypted_bin_sum = encrypted_bin_sum.map(encode_name_f)

        self.transfer_variable.encrypted_bin_sum.remote(encrypted_bin_sum,
                                                        role=consts.GUEST,
                                                        idx=0)
        # Meta info the guest needs to run (optimal) binning on our behalf.
        send_result = {
            "category_names": self.bin_inner_param.get_anonymous_col_name_list(
                self.bin_inner_param.category_names),
            "bin_method": self.model_param.method,
            "optimal_params": {
                "metric_method": self.model_param.optimal_binning_param.metric_method,
                "bin_num": self.model_param.bin_num,
                "mixture": self.model_param.optimal_binning_param.mixture,
                "max_bin_pct": self.model_param.optimal_binning_param.max_bin_pct,
                "min_bin_pct": self.model_param.optimal_binning_param.min_bin_pct
            }
        }
        self.transfer_variable.optimal_info.remote(send_result,
                                                   role=consts.GUEST,
                                                   idx=0)

    def __static_encrypted_bin_label(self, data_bin_table, encrypted_label):
        """Aggregate the encrypted labels per (column, bin).

        Returns whatever table ``iv_calculator.cal_bin_label`` produces.
        """
        # Total (encrypted) label count across all instances.
        label_counts = encrypted_label.reduce(operator.add)
        sparse_bin_points = self.binning_obj.get_sparse_bin(self.bin_inner_param.bin_indexes,
                                                            self.binning_obj.bin_results.all_split_points,
                                                            self.bin_inner_param.header)
        # Re-key from column index to column name.
        sparse_bin_points = {self.bin_inner_param.header[k]: v for k, v in sparse_bin_points.items()}
        encrypted_bin_sum = self.iv_calculator.cal_bin_label(
            data_bin_table=data_bin_table,
            sparse_bin_points=sparse_bin_points,
            label_table=encrypted_label,
            label_counts=label_counts
        )
        return encrypted_bin_sum

    @staticmethod
    def convert_compress_format(col_name, encrypted_bin_sum):
        """Split per-bin [event, non_event] pairs into two parallel lists.

        Parameters
        ----------
        col_name : str
            Name of the feature column the sums belong to.
        encrypted_bin_sum : list
            One entry per bin, each of the form
            ``[event_count, non_event_count]``.

        Returns
        -------
        tuple
            ``(col_name, {"event_counts": [...], "non_event_counts": [...]})``
        """
        event_counts = [x[0] for x in encrypted_bin_sum]
        non_event_counts = [x[1] for x in encrypted_bin_sum]
        return col_name, {"event_counts": event_counts, "non_event_counts": non_event_counts}

    def optimal_binning_sync(self):
        """Receive the guest-chosen bucket indexes and shrink the local
        split points to the selected subset."""
        bucket_idx = self.transfer_variable.bucket_idx.get(idx=0)
        original_split_points = self.binning_obj.bin_results.all_split_points
        for anonymous_col_name, b_idx in bucket_idx.items():
            # The guest only knows anonymous names; map back to real ones.
            col_name = self.bin_inner_param.get_col_name_by_anonymous(anonymous_col_name)
            ori_sp_list = original_split_points.get(col_name)
            optimal_result = [ori_sp_list[i] for i in b_idx]
            self.binning_obj.bin_results.put_col_split_points(col_name, optimal_result)
|