123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from federatedml.model_base import MetricMeta
- from federatedml.feature.feature_scale.min_max_scale import MinMaxScale
- from federatedml.feature.feature_scale.standard_scale import StandardScale
- from federatedml.model_base import ModelBase
- from federatedml.param.scale_param import ScaleParam
- from federatedml.util import consts
- from federatedml.util import LOGGER
- from federatedml.util.io_check import assert_io_num_rows_equal
- from federatedml.util.schema_check import assert_schema_consistent
class Scale(ModelBase):
    """
    Data scaling component. MinMaxScale and StandardScale are supported;
    for any other configured method the input data is returned unchanged.
    """

    def __init__(self):
        super().__init__()
        self.model_name = None
        self.model_param_name = 'ScaleParam'
        self.model_meta_name = 'ScaleMeta'
        self.model_param = ScaleParam()

        self.scale_param_obj = None
        # Scaler currently in use; (re)built by fit/transform/export_model.
        self.scale_obj = None

        # Fitted statistics, populated by load_model() and handed to the
        # scaler in transform().
        self.header = None
        self.column_max_value = None
        self.column_min_value = None
        self.mean = None
        self.std = None
        self.scale_column_idx = None

    def _create_scale_obj(self):
        """
        Build a fresh scaler for the configured method.

        Returns
        ----------
        MinMaxScale or StandardScale instance, or None when
        self.model_param.method is neither of the two supported methods.
        """
        method = self.model_param.method
        if method == consts.MINMAXSCALE:
            return MinMaxScale(self.model_param)
        if method == consts.STANDARDSCALE:
            return StandardScale(self.model_param)
        return None

    def _callback_scale_meta(self):
        """Report the scale method in use through callback_meta."""
        self.callback_meta(metric_name="scale", metric_namespace="train",
                           metric_meta=MetricMeta(name="scale", metric_type="SCALE",
                                                  extra_metas={"method": self.model_param.method}))

    def fit(self, data):
        """
        Fit the configured scaler on the input data and return the scaled data

        Parameters
        ----------
        data: data_instance, input data

        Returns
        ----------
        fit_data: data_instance, data after scale; the input itself when the
                  configured method is not supported
        """
        LOGGER.info("Start scale data fit ...")

        # Always rebuild the scaler so that an unsupported method can never
        # fall through to a scaler left over from an earlier call.
        self.scale_obj = self._create_scale_obj()

        if self.scale_obj is None:
            LOGGER.warning("Scale method is {}, do nothing and return!".format(self.model_param.method))
            fit_data = data
        else:
            fit_data = self.scale_obj.fit(data)
            fit_data.schema = data.schema

            self._callback_scale_meta()

            LOGGER.info("start to get model summary ...")
            self.set_summary(self.scale_obj.get_model_summary())
            LOGGER.info("Finish getting model summary.")

        LOGGER.info("End fit data ...")
        return fit_data

    @assert_io_num_rows_equal
    @assert_schema_consistent
    def transform(self, data, fit_config=None):
        """
        Transform input data using scale with fit results

        Parameters
        ----------
        data: data_instance, input data
        fit_config: not read by this method; kept so existing callers that
                    pass it keep working

        Returns
        ----------
        transform_data: data_instance, data after transform; the input itself
                        when the configured method is not supported
        """
        LOGGER.info("Start scale data transform ...")

        # Rebuild the scaler for the configured method (None if unsupported),
        # dropping any scaler kept from a previous call.
        self.scale_obj = self._create_scale_obj()

        if self.scale_obj is None:
            LOGGER.warning("Scale method is {}, do nothing and return!".format(self.model_param.method))
            transform_data = data
        else:
            if self.model_param.method == consts.STANDARDSCALE:
                self.scale_obj.set_param(self.mean, self.std)

            # Hand the stored fit results to the fresh scaler before use.
            self.scale_obj.header = self.header
            self.scale_obj.scale_column_idx = self.scale_column_idx
            self.scale_obj.set_column_range(self.column_max_value, self.column_min_value)

            transform_data = self.scale_obj.transform(data)
            transform_data.schema = data.schema

            self._callback_scale_meta()

        LOGGER.info("End transform data.")
        return transform_data

    def load_model(self, model_dict):
        """
        Restore header, method and per-column scale statistics from a saved model

        Parameters
        ----------
        model_dict: dict, its 'model' entry is a mapping whose first value
                    holds the objects stored under self.model_param_name and
                    self.model_meta_name
        """
        # Both the param and the meta object live in the same (first) entry.
        model_content = list(model_dict.get('model').values())[0]
        model_obj = model_content.get(self.model_param_name)
        meta_obj = model_content.get(self.model_meta_name)

        self.header = list(model_obj.header)
        self.need_run = meta_obj.need_run
        self.model_param.method = meta_obj.method

        # Columns without a stored scale param keep these neutral defaults.
        shape = len(self.header)
        self.column_max_value = [0] * shape
        self.column_min_value = [0] * shape
        self.mean = [0] * shape
        self.std = [1] * shape
        self.scale_column_idx = []

        scale_param_dict = dict(model_obj.col_scale_param)
        # Name -> position lookup, so each column is resolved without
        # scanning the header list.
        header_index_mapping = {name: idx for idx, name in enumerate(self.header)}

        for key, column_scale_param in scale_param_dict.items():
            index = header_index_mapping[key]
            self.scale_column_idx.append(index)

            self.column_max_value[index] = column_scale_param.column_upper
            self.column_min_value[index] = column_scale_param.column_lower
            self.mean[index] = column_scale_param.mean
            self.std[index] = column_scale_param.std

        self.scale_column_idx.sort()

    def export_model(self):
        """
        Export the model of the current scaler

        When no scaler is set yet, one is created first: MinMaxScale for the
        min-max method, StandardScale for every other method.

        Returns
        ----------
        the result of the scaler's export_model(self.need_run)
        """
        if self.scale_obj is None:
            if self.model_param.method == consts.MINMAXSCALE:
                self.scale_obj = MinMaxScale(self.model_param)
            else:
                self.scale_obj = StandardScale(self.model_param)

        return self.scale_obj.export_model(self.need_run)
|