#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from federatedml.model_base import ModelBase
from federatedml.param.feature_binning_param import HomoFeatureBinningParam
from federatedml.feature.homo_feature_binning import virtual_summary_binning, recursive_query_binning
from federatedml.util import consts
from federatedml.feature.hetero_feature_binning.base_feature_binning import BaseFeatureBinning
from federatedml.transfer_variable.transfer_class.homo_binning_transfer_variable import HomoBinningTransferVariable


class HomoBinningArbiter(BaseFeatureBinning):
    """Arbiter-side homo feature binning component.

    Delegates the work to a ``Server`` binning object selected by
    ``model_param.method``.
    """

    def __init__(self):
        super().__init__()
        # Created in _init_model once the binning method is known.
        self.binning_obj = None
        self.transfer_variable = HomoBinningTransferVariable()
        self.model_param = HomoFeatureBinningParam()

    def _init_model(self, model_param):
        """Store ``model_param`` and build the matching server binning object.

        Raises:
            ValueError: if ``model_param.method`` is neither
                ``consts.VIRTUAL_SUMMARY`` nor ``consts.RECURSIVE_QUERY``.
        """
        self.model_param = model_param
        if self.model_param.method == consts.VIRTUAL_SUMMARY:
            self.binning_obj = virtual_summary_binning.Server(self.model_param)
        elif self.model_param.method == consts.RECURSIVE_QUERY:
            self.binning_obj = recursive_query_binning.Server(self.model_param)
        else:
            raise ValueError(f"Method: {self.model_param.method} cannot be recognized")

    def fit(self, *args):
        """Attach the transfer variable and run the server's split-point fitting.

        Positional arguments are accepted but not used.
        """
        self.binning_obj.set_transfer_variable(self.transfer_variable)
        self.binning_obj.fit_split_points()

    def transform(self, data_instances):
        """No-op on the arbiter side; returns ``None``."""
        pass


class HomoBinningClient(BaseFeatureBinning):
    """Client-side homo feature binning component.

    Delegates the split-point computation to a ``Client`` binning object
    selected by ``model_param.method``.
    """

    def __init__(self):
        super().__init__()
        # Created in _init_model once the binning method is known.
        self.binning_obj = None
        self.transfer_variable = HomoBinningTransferVariable()
        self.model_param = HomoFeatureBinningParam()

    def _init_model(self, model_param: HomoFeatureBinningParam):
        """Store ``model_param`` and build the matching client binning object.

        Raises:
            ValueError: if ``model_param.method`` is neither
                ``consts.VIRTUAL_SUMMARY`` nor ``consts.RECURSIVE_QUERY``.
        """
        # Install the new parameters first so that transform_type is taken
        # from the object passed in, not from the previously stored one.
        self.model_param = model_param
        self.transform_type = self.model_param.transform_param.transform_type

        if self.model_param.method == consts.VIRTUAL_SUMMARY:
            self.binning_obj = virtual_summary_binning.Client(self.model_param)
        elif self.model_param.method == consts.RECURSIVE_QUERY:
            self.binning_obj = recursive_query_binning.Client(
                role=self.component_properties.role,
                params=self.model_param,
            )
        else:
            raise ValueError(f"Method: {self.model_param.method} cannot be recognized")

    def fit(self, data_instances):
        """Fit split points on ``data_instances`` and return the transformed data.

        The split points returned by the binning object are recorded in the
        component summary under the ``"split_points"`` key, with each value
        converted to a list.

        Returns:
            The result of ``self.transform(data_instances)``.
        """
        self._abnormal_detection(data_instances)
        self._setup_bin_inner_param(data_instances, self.model_param)

        # The binning object is fitted on the format-transformed copy, while
        # the final transform below is applied to the original instances.
        transformed_instances = data_instances.mapValues(self.data_format_transform)
        transformed_instances.schema = self.schema

        self.binning_obj.set_bin_inner_param(self.bin_inner_param)
        self.binning_obj.set_transfer_variable(self.transfer_variable)
        split_points = self.binning_obj.fit_split_points(transformed_instances)

        data_out = self.transform(data_instances)

        summary = {name: list(points) for name, points in split_points.items()}
        self.set_summary({"split_points": summary})
        return data_out

    def transform(self, data_instances):
        """Return ``self.transform_data(data_instances)``."""
        return self.transform_data(data_instances)