123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import numpy as np
- from federatedml.feature.binning.quantile_tool import QuantileBinningTool
- from federatedml.feature.homo_feature_binning import homo_binning_base
- from federatedml.param.feature_binning_param import HomoFeatureBinningParam
- from federatedml.util import consts
- from federatedml.util import LOGGER
- from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
- import copy
- import operator
- import functools
class Server(homo_binning_base.Server):
    """Server-side driver for recursive-query homo feature binning.

    ``fit_split_points`` never reads its ``data`` argument; it only invokes
    the steps inherited from ``homo_binning_base.Server`` in the same
    sequence that ``Client.fit_split_points`` and
    ``Client._recursive_querying`` in this module follow.
    """

    def __init__(self, params: HomoFeatureBinningParam, abnormal_list=None):
        """Forward ``params`` and ``abnormal_list`` unchanged to the base class."""
        super().__init__(params, abnormal_list)

    def fit_split_points(self, data=None):
        """Run the server half of the binning rounds.

        Parameters
        ----------
        data :
            Unused; accepted so the signature matches the client side.

        Returns
        -------
        None
        """
        # Create the aggregator lazily so one injected beforehand is kept.
        if self.aggregator is None:
            self.aggregator = SecureAggregatorServer(
                secure_aggregate=True,
                communicate_match_suffix='recursive_query_binning',
            )

        # One-off statistics steps that precede the query rounds.
        self.get_total_count()
        self.get_min_max()
        self.get_missing_count()

        # Initial query round, tagged with suffix -1.
        self.set_suffix(-1)
        self.query_values()

        # Refinement rounds: stop as soon as the received convergence flag is
        # truthy, otherwise serve another query round.  At most
        # ``params.max_iter`` rounds are executed.
        round_idx = 0
        while round_idx < self.params.max_iter:
            self.set_suffix(round_idx)
            if self.transfer_variable.is_converge.get(suffix=self.suffix)[0]:
                break
            self.query_values()
            round_idx += 1
class Client(homo_binning_base.Client):
    """Client side of recursive-query homo feature binning.

    ``fit_split_points`` builds a local quantile summary and then refines a
    set of candidate split points per feature: in every round the ranks
    returned by the inherited ``query_values`` are compared with each
    candidate's ``aim_rank`` and the candidate is moved left or right until
    the difference is within its ``allow_error_rank`` (or until
    ``params.max_iter`` rounds have run).
    """

    def __init__(self, role, params: HomoFeatureBinningParam = None,
                 abnormal_list=None, allow_duplicate=False):
        """Store the binning configuration.

        Parameters
        ----------
        role :
            Party role; only a client whose role equals ``consts.GUEST``
            sends the convergence flag in ``_recursive_querying``.
        params : HomoFeatureBinningParam
            Binning parameters; ``error``, ``bin_num`` and ``max_iter`` are
            read by this class.
        abnormal_list :
            Forwarded to the base class and to ``QuantileBinningTool``.
        allow_duplicate : bool
            When False, near-identical split points are merged before they
            are stored.
        """
        super().__init__(params, abnormal_list)
        self.allow_duplicate = allow_duplicate
        self.global_ranks = {}
        self.total_count = 0
        self.missing_counts = 0
        # NOTE(review): ``params`` defaults to None in the signature but is
        # dereferenced here, so omitting it raises AttributeError.
        self.error = params.error
        self.error_rank = None
        self.role = role

    def fit_split_points(self, data_instances):
        """Compute split points for every feature and store them.

        Parameters
        ----------
        data_instances :
            Input table handed to the inherited statistics helpers and to
            ``QuantileBinningTool.fit_summary``.

        Returns
        -------
        ``self.bin_results.all_split_points`` after every feature's split
        points have been stored via ``put_col_split_points``.
        """
        # Create the aggregator lazily so one injected beforehand is kept.
        if self.aggregator is None:
            self.aggregator = SecureAggregatorClient(
                secure_aggregate=True,
                aggregate_type='sum',
                communicate_match_suffix='recursive_query_binning',
            )
        if self.bin_inner_param is None:
            self._setup_bin_inner_param(data_instances, self.params)

        self.total_count = self.get_total_count(data_instances)
        # Tolerated rank deviation: the relative error scaled by the total
        # count, rounded up.
        self.error_rank = np.ceil(self.error * self.total_count)
        LOGGER.debug(f"abnormal_list: {self.abnormal_list}")

        quantile_tool = QuantileBinningTool(param_obj=self.params,
                                            abnormal_list=self.abnormal_list,
                                            allow_duplicate=self.allow_duplicate)
        quantile_tool.set_bin_inner_param(self.bin_inner_param)
        summary_table = quantile_tool.fit_summary(data_instances)

        # The return value is intentionally unused; the call is kept in this
        # position because the server performs the matching step here.
        self.get_min_max(data_instances)
        self.missing_counts = self.get_missing_count(summary_table)

        split_points_table = self._recursive_querying(summary_table)
        split_points = dict(split_points_table.collect())
        for col_name, nodes in split_points.items():
            values = [node.value for node in nodes]
            if not self.allow_duplicate:
                values = self._merge_close_split_points(values)
            self.bin_results.put_col_split_points(col_name, values)
        return self.bin_results.all_split_points

    @staticmethod
    def _merge_close_split_points(values):
        """Return the sorted distinct ``values`` with near-duplicates dropped.

        Values that are within ``consts.FLOAT_ZERO`` of zero are snapped to
        ``0.0``; a value is kept only when it differs from the previously
        kept value by more than ``consts.FLOAT_ZERO``.

        Parameters
        ----------
        values : list
            Non-empty list of split-point values.

        Returns
        -------
        numpy.ndarray
            The retained values in ascending order.
        """
        ordered = sorted(set(values))
        first = ordered[0] if np.fabs(ordered[0]) > consts.FLOAT_ZERO else 0.0
        kept = [first]
        # Compare against the value that was actually emitted (``first``),
        # not the raw ``ordered[0]``, so the "differs by more than
        # FLOAT_ZERO" rule holds between every pair of kept values.
        last = first
        for value in ordered[1:]:
            if np.fabs(value) < consts.FLOAT_ZERO:
                value = 0.0
            if np.abs(value - last) > consts.FLOAT_ZERO:
                kept.append(value)
                last = value
        return np.array(kept)

    @staticmethod
    def _set_aim_rank(feature_name, split_point_array, missing_dict, total_counts, split_num):
        """Assign a target rank to every query node of one feature.

        The target ranks are ``floor(q * n)`` for ``split_num`` evenly spaced
        ``q`` in ``[0, 1]`` with the first (``q == 0``) dropped, where ``n``
        is ``total_counts`` minus the feature's missing count.

        Returns
        -------
        tuple
            ``(feature_name, split_point_array)`` with each node updated in
            place through ``set_aim_rank``.
        """
        valid_count = total_counts - missing_dict[feature_name]
        aim_ranks = np.floor(np.linspace(0, 1, split_num) * valid_count)[1:]
        for idx, node in enumerate(split_point_array):
            node.set_aim_rank(aim_ranks[idx])
        return feature_name, split_point_array

    def _recursive_querying(self, summary_table):
        """Refine the query points until every node is fixed.

        Parameters
        ----------
        summary_table :
            Table of per-feature quantile summaries.

        Returns
        -------
        The query-point table produced by the last round (feature name ->
        sequence of query nodes).
        """
        self.set_suffix(-1)
        split_num = self.params.bin_num + 1
        query_points_table = self.init_query_points(summary_table.partitions,
                                                    split_num=split_num,
                                                    error_rank=self.error_rank,
                                                    need_first=False)
        assign_aim_rank = functools.partial(self._set_aim_rank,
                                            missing_dict=self.missing_counts,
                                            total_counts=self.total_count,
                                            split_num=split_num)
        query_points_table = query_points_table.map(assign_aim_rank)
        global_ranks = self.query_values(summary_table, query_points_table)

        # The statement order inside the loop matters: the server waits for
        # the convergence flag before it serves the next ``query_values``.
        n_iter = 0
        while n_iter < self.params.max_iter:
            self.set_suffix(n_iter)
            query_points_table = query_points_table.join(global_ranks, self.renew_query_points_table)
            is_converge = self.check_converge(query_points_table)
            if self.role == consts.GUEST:
                self.transfer_variable.is_converge.remote(is_converge, suffix=self.suffix)
            LOGGER.debug(f"n_iter: {n_iter}, converged: {is_converge}")
            if is_converge:
                break
            global_ranks = self.query_values(summary_table, query_points_table)
            n_iter += 1
        return query_points_table

    @staticmethod
    def renew_query_points_table(query_points, ranks):
        """Produce the next generation of query nodes for one feature.

        For each node and its returned rank:

        * an already fixed node is copied unchanged;
        * a rank above ``aim_rank`` by more than ``allow_error_rank`` yields
          ``node.create_left_new()``;
        * a rank below ``aim_rank`` by more than ``allow_error_rank`` yields
          ``node.create_right_new()``;
        * otherwise the node is copied and marked ``fixed``.

        Every resulting node records the rank in ``last_rank``.

        Returns
        -------
        list
            New nodes, in the same order as ``query_points``.
        """
        # Internal invariant: one rank per query node.
        assert len(query_points) == len(ranks)
        renewed = []
        for idx, node in enumerate(query_points):
            rank = ranks[idx]
            if node.fixed:
                new_node = copy.deepcopy(node)
            elif rank - node.aim_rank > node.allow_error_rank:
                new_node = node.create_left_new()
            elif node.aim_rank - rank > node.allow_error_rank:
                new_node = node.create_right_new()
            else:
                new_node = copy.deepcopy(node)
                new_node.fixed = True
            new_node.last_rank = rank
            renewed.append(new_node)
        return renewed

    @staticmethod
    def check_converge(query_table):
        """Return whether every query node of every feature is fixed.

        Parameters
        ----------
        query_table :
            Table whose values are sequences of query nodes.

        Returns
        -------
        The ``operator.and_`` reduction of the per-feature results.
        """
        def is_all_fixed(node_array):
            # ``all`` yields True for an empty sequence, where an
            # initial-value-free ``functools.reduce`` would raise TypeError.
            return all(node.fixed for node in node_array)

        fix_table = query_table.mapValues(is_all_fixed)
        return fix_table.reduce(operator.and_)
|