123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from federatedml.model_base import ModelBase
- from federatedml.param.feature_binning_param import HomoFeatureBinningParam
- from federatedml.feature.homo_feature_binning import virtual_summary_binning, recursive_query_binning
- from federatedml.util import consts
- from federatedml.feature.hetero_feature_binning.base_feature_binning import BaseFeatureBinning
- from federatedml.transfer_variable.transfer_class.homo_binning_transfer_variable import HomoBinningTransferVariable
class HomoBinningArbiter(BaseFeatureBinning):
    """Arbiter-side driver for homogeneous feature binning.

    The arbiter builds a binning ``Server`` object for the configured
    method and runs its split-point fitting over the shared transfer
    variable. Its ``fit`` ignores any positional arguments and its
    ``transform`` does nothing.
    """

    def __init__(self):
        super().__init__()
        # The concrete Server is only known once _init_model sees the method.
        self.binning_obj = None
        self.transfer_variable = HomoBinningTransferVariable()
        self.model_param = HomoFeatureBinningParam()

    def _init_model(self, model_param):
        """Store *model_param* and build the binning server it selects.

        Args:
            model_param: binning parameters; ``model_param.method`` picks
                the server implementation.

        Raises:
            ValueError: if the method is neither ``consts.VIRTUAL_SUMMARY``
                nor ``consts.RECURSIVE_QUERY``.
        """
        self.model_param = model_param
        method = self.model_param.method

        if method == consts.VIRTUAL_SUMMARY:
            server_factory = virtual_summary_binning.Server
        elif method == consts.RECURSIVE_QUERY:
            server_factory = recursive_query_binning.Server
        else:
            raise ValueError(f"Method: {method} cannot be recognized")

        self.binning_obj = server_factory(self.model_param)

    def fit(self, *args):
        """Run server-side split-point fitting; positional args are unused."""
        server = self.binning_obj
        server.set_transfer_variable(self.transfer_variable)
        server.fit_split_points()

    def transform(self, data_instances):
        """No-op on the arbiter; always returns ``None``."""
        return None
class HomoBinningClient(BaseFeatureBinning):
    """Client-side driver for homogeneous feature binning.

    The client prepares its local ``data_instances``, hands them to a
    binning ``Client`` object for the configured method to fit split
    points (communicating through the shared transfer variable), records
    the split points in the model summary and returns the transformed data.
    """

    def __init__(self):
        super().__init__()
        # The concrete Client is only known once _init_model sees the method.
        self.binning_obj = None
        self.transfer_variable = HomoBinningTransferVariable()
        self.model_param = HomoFeatureBinningParam()

    def _init_model(self, model_param: HomoFeatureBinningParam):
        """Store *model_param* and build the binning client it selects.

        Args:
            model_param: binning parameters; ``model_param.method`` picks
                the client implementation, and
                ``model_param.transform_param.transform_type`` is cached
                on ``self.transform_type``.

        Raises:
            ValueError: if the method is neither ``consts.VIRTUAL_SUMMARY``
                nor ``consts.RECURSIVE_QUERY``.
        """
        # Assign the supplied parameters before reading from them, so that
        # transform_type comes from the caller's configuration rather than
        # from the default HomoFeatureBinningParam() created in __init__.
        self.model_param = model_param
        self.transform_type = self.model_param.transform_param.transform_type

        if self.model_param.method == consts.VIRTUAL_SUMMARY:
            self.binning_obj = virtual_summary_binning.Client(self.model_param)
        elif self.model_param.method == consts.RECURSIVE_QUERY:
            self.binning_obj = recursive_query_binning.Client(
                role=self.component_properties.role,
                params=self.model_param,
            )
        else:
            raise ValueError(f"Method: {self.model_param.method} cannot be recognized")

    def fit(self, data_instances):
        """Fit split points on *data_instances* and return the transformed data.

        The helpers used here (``_abnormal_detection``,
        ``_setup_bin_inner_param``, ``data_format_transform``,
        ``transform_data``, ``set_summary``) are not defined in this class;
        presumably they come from ``BaseFeatureBinning`` — verify there.

        Args:
            data_instances: the local dataset; assumed to be a table-like
                object exposing ``mapValues`` (TODO confirm type).

        Returns:
            The result of ``self.transform(data_instances)``.
        """
        self._abnormal_detection(data_instances)
        self._setup_bin_inner_param(data_instances, self.model_param)

        # Convert rows to the format the binning object expects, keeping the schema.
        transformed_instances = data_instances.mapValues(self.data_format_transform)
        transformed_instances.schema = self.schema

        self.binning_obj.set_bin_inner_param(self.bin_inner_param)
        self.binning_obj.set_transfer_variable(self.transfer_variable)
        split_points = self.binning_obj.fit_split_points(transformed_instances)

        data_out = self.transform(data_instances)

        # Presumably keyed by feature name; values are materialised as lists
        # so the summary holds plain Python containers.
        summary = {name: list(points) for name, points in split_points.items()}
        self.set_summary({"split_points": summary})
        return data_out

    def transform(self, data_instances):
        """Delegate to ``transform_data`` (not defined in this class)."""
        return self.transform_data(data_instances)
|