scale.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from federatedml.model_base import MetricMeta
  17. from federatedml.feature.feature_scale.min_max_scale import MinMaxScale
  18. from federatedml.feature.feature_scale.standard_scale import StandardScale
  19. from federatedml.model_base import ModelBase
  20. from federatedml.param.scale_param import ScaleParam
  21. from federatedml.util import consts
  22. from federatedml.util import LOGGER
  23. from federatedml.util.io_check import assert_io_num_rows_equal
  24. from federatedml.util.schema_check import assert_schema_consistent
  25. class Scale(ModelBase):
  26. """
  27. The Scale class is used to data scale. MinMaxScale and StandardScale is supported now
  28. """
  29. def __init__(self):
  30. super().__init__()
  31. self.model_name = None
  32. self.model_param_name = 'ScaleParam'
  33. self.model_meta_name = 'ScaleMeta'
  34. self.model_param = ScaleParam()
  35. self.scale_param_obj = None
  36. self.scale_obj = None
  37. self.header = None
  38. self.column_max_value = None
  39. self.column_min_value = None
  40. self.mean = None
  41. self.std = None
  42. self.scale_column_idx = None
  43. def fit(self, data):
  44. """
  45. Apply scale for input data
  46. Parameters
  47. ----------
  48. data: data_instance, input data
  49. Returns
  50. ----------
  51. data:data_instance, data after scale
  52. scale_value_results: list, the fit results information of scale
  53. """
  54. LOGGER.info("Start scale data fit ...")
  55. if self.model_param.method == consts.MINMAXSCALE:
  56. self.scale_obj = MinMaxScale(self.model_param)
  57. elif self.model_param.method == consts.STANDARDSCALE:
  58. self.scale_obj = StandardScale(self.model_param)
  59. else:
  60. LOGGER.warning("Scale method is {}, do nothing and return!".format(self.model_param.method))
  61. if self.scale_obj:
  62. fit_data = self.scale_obj.fit(data)
  63. fit_data.schema = data.schema
  64. self.callback_meta(metric_name="scale", metric_namespace="train",
  65. metric_meta=MetricMeta(name="scale", metric_type="SCALE",
  66. extra_metas={"method": self.model_param.method}))
  67. LOGGER.info("start to get model summary ...")
  68. self.set_summary(self.scale_obj.get_model_summary())
  69. LOGGER.info("Finish getting model summary.")
  70. else:
  71. fit_data = data
  72. LOGGER.info("End fit data ...")
  73. return fit_data
  74. @assert_io_num_rows_equal
  75. @assert_schema_consistent
  76. def transform(self, data, fit_config=None):
  77. """
  78. Transform input data using scale with fit results
  79. Parameters
  80. ----------
  81. data: data_instance, input data
  82. fit_config: list, the fit results information of scale
  83. Returns
  84. ----------
  85. transform_data:data_instance, data after transform
  86. """
  87. LOGGER.info("Start scale data transform ...")
  88. if self.model_param.method == consts.MINMAXSCALE:
  89. self.scale_obj = MinMaxScale(self.model_param)
  90. elif self.model_param.method == consts.STANDARDSCALE:
  91. self.scale_obj = StandardScale(self.model_param)
  92. self.scale_obj.set_param(self.mean, self.std)
  93. else:
  94. LOGGER.info("DataTransform method is {}, do nothing and return!".format(self.model_param.method))
  95. if self.scale_obj:
  96. self.scale_obj.header = self.header
  97. self.scale_obj.scale_column_idx = self.scale_column_idx
  98. self.scale_obj.set_column_range(self.column_max_value, self.column_min_value)
  99. transform_data = self.scale_obj.transform(data)
  100. transform_data.schema = data.schema
  101. self.callback_meta(metric_name="scale", metric_namespace="train",
  102. metric_meta=MetricMeta(name="scale", metric_type="SCALE",
  103. extra_metas={"method": self.model_param.method}))
  104. else:
  105. transform_data = data
  106. LOGGER.info("End transform data.")
  107. return transform_data
  108. def load_model(self, model_dict):
  109. model_obj = list(model_dict.get('model').values())[0].get(self.model_param_name)
  110. meta_obj = list(model_dict.get('model').values())[0].get(self.model_meta_name)
  111. self.header = list(model_obj.header)
  112. self.need_run = meta_obj.need_run
  113. self.model_param.method = meta_obj.method
  114. shape = len(self.header)
  115. self.column_max_value = [0 for _ in range(shape)]
  116. self.column_min_value = [0 for _ in range(shape)]
  117. self.mean = [0 for _ in range(shape)]
  118. self.std = [1 for _ in range(shape)]
  119. self.scale_column_idx = []
  120. scale_param_dict = dict(model_obj.col_scale_param)
  121. header_index_mapping = dict(zip(self.header, range(len(self.header))))
  122. for key, column_scale_param in scale_param_dict.items():
  123. # index = self.header.index(key)
  124. index = header_index_mapping[key]
  125. self.scale_column_idx.append(index)
  126. self.column_max_value[index] = column_scale_param.column_upper
  127. self.column_min_value[index] = column_scale_param.column_lower
  128. self.mean[index] = column_scale_param.mean
  129. self.std[index] = column_scale_param.std
  130. self.scale_column_idx.sort()
  131. def export_model(self):
  132. if not self.scale_obj:
  133. if self.model_param.method == consts.MINMAXSCALE:
  134. self.scale_obj = MinMaxScale(self.model_param)
  135. else:
  136. self.scale_obj = StandardScale(self.model_param)
  137. return self.scale_obj.export_model(self.need_run)