# base_scale.py (10 KB)
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import functools

try:
    # The ABC aliases were removed from ``collections`` in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # very old interpreters only
    from collections import Iterable

from federatedml.statistic import data_overview
from federatedml.statistic.data_overview import get_header
from federatedml.statistic.statics import MultivariateStatisticalSummary
from federatedml.util import consts
from federatedml.util import LOGGER
  24. class BaseScale(object):
  25. def __init__(self, params):
  26. # self.area = params.area
  27. self.mode = params.mode
  28. self.param_scale_col_indexes = params.scale_col_indexes
  29. self.param_scale_names = params.scale_names
  30. self.feat_upper = params.feat_upper
  31. self.feat_lower = params.feat_lower
  32. self.data_shape = None
  33. self.header = None
  34. self.scale_column_idx = []
  35. self.summary_obj = None
  36. self.model_param_name = 'ScaleParam'
  37. self.model_meta_name = 'ScaleMeta'
  38. self.column_min_value = None
  39. self.column_max_value = None
  40. self.round_num = 6
  41. def _get_data_shape(self, data):
  42. if not self.data_shape:
  43. self.data_shape = data_overview.get_features_shape(data)
  44. return self.data_shape
  45. def _get_header(self, data):
  46. header = get_header(data)
  47. return header
  48. def _get_upper(self, data_shape):
  49. if isinstance(self.feat_upper, Iterable):
  50. return list(map(str, self.feat_upper))
  51. else:
  52. if self.feat_upper is None:
  53. return ["None" for _ in range(data_shape)]
  54. else:
  55. return [str(self.feat_upper) for _ in range(data_shape)]
  56. def _get_lower(self, data_shape):
  57. if isinstance(self.feat_lower, Iterable):
  58. return list(map(str, self.feat_lower))
  59. else:
  60. if self.feat_lower is None:
  61. return ["None" for _ in range(data_shape)]
  62. else:
  63. return [str(self.feat_lower) for _ in range(data_shape)]
  64. def _get_scale_column_idx(self, data):
  65. data_shape = self._get_data_shape(data)
  66. if self.param_scale_col_indexes != -1:
  67. if isinstance(self.param_scale_col_indexes, list):
  68. if len(self.param_scale_col_indexes) > 0:
  69. max_col_idx = max(self.param_scale_col_indexes)
  70. if max_col_idx >= data_shape:
  71. raise ValueError(
  72. "max column index in area is:{}, should less than data shape:{}".format(max_col_idx,
  73. data_shape))
  74. scale_column_idx = self.param_scale_col_indexes
  75. header = data_overview.get_header(data)
  76. scale_names = set(header).intersection(set(self.param_scale_names))
  77. idx_from_name = list(map(lambda n: header.index(n), scale_names))
  78. scale_column_idx = scale_column_idx + idx_from_name
  79. scale_column_idx = sorted(set(scale_column_idx))
  80. else:
  81. LOGGER.warning(
  82. "parameter scale_column_idx should be a list, but not:{}, set scale column to all columns".format(
  83. type(self.param_scale_col_indexes)))
  84. scale_column_idx = [i for i in range(data_shape)]
  85. else:
  86. scale_column_idx = [i for i in range(data_shape)]
  87. return scale_column_idx
  88. def __check_equal(self, size1, size2):
  89. if size1 != size2:
  90. raise ValueError("Check equal failed, {} != {}".format(size1, size2))
  91. def __get_min_max_value_by_normal(self, data):
  92. data_shape = self._get_data_shape(data)
  93. self.summary_obj = MultivariateStatisticalSummary(data, -1)
  94. header = data.schema.get("header")
  95. column_min_value = self.summary_obj.get_min()
  96. column_min_value = [column_min_value[key] for key in header]
  97. column_max_value = self.summary_obj.get_max()
  98. column_max_value = [column_max_value[key] for key in header]
  99. scale_column_idx_set = set(self._get_scale_column_idx(data))
  100. if self.feat_upper is not None:
  101. if isinstance(self.feat_upper, list):
  102. self.__check_equal(data_shape, len(self.feat_upper))
  103. for i in range(data_shape):
  104. if i in scale_column_idx_set:
  105. if column_max_value[i] > self.feat_upper[i]:
  106. column_max_value[i] = self.feat_upper[i]
  107. if column_min_value[i] > self.feat_upper[i]:
  108. column_min_value[i] = self.feat_upper[i]
  109. else:
  110. for i in range(data_shape):
  111. if i in scale_column_idx_set:
  112. if column_max_value[i] > self.feat_upper:
  113. column_max_value[i] = self.feat_upper
  114. if column_min_value[i] > self.feat_upper:
  115. column_min_value[i] = self.feat_upper
  116. if self.feat_lower is not None:
  117. if isinstance(self.feat_lower, list):
  118. self.__check_equal(data_shape, len(self.feat_lower))
  119. for i in range(data_shape):
  120. if i in scale_column_idx_set:
  121. if column_min_value[i] < self.feat_lower[i]:
  122. column_min_value[i] = self.feat_lower[i]
  123. if column_max_value[i] < self.feat_lower[i]:
  124. column_max_value[i] = self.feat_lower[i]
  125. else:
  126. for i in range(data_shape):
  127. if i in scale_column_idx_set:
  128. if column_min_value[i] < self.feat_lower:
  129. column_min_value[i] = self.feat_lower
  130. if column_max_value[i] < self.feat_lower:
  131. column_max_value[i] = self.feat_lower
  132. return column_min_value, column_max_value
  133. def __get_min_max_value_by_cap(self, data):
  134. data_shape = self._get_data_shape(data)
  135. self.summary_obj = MultivariateStatisticalSummary(data, -1)
  136. header = data.schema.get("header")
  137. if self.feat_upper is None:
  138. self.feat_upper = 1.0
  139. if self.feat_lower is None:
  140. self.feat_lower = 0
  141. if self.feat_upper < self.feat_lower:
  142. raise ValueError("feat_upper should not less than feat_lower")
  143. column_min_value = self.summary_obj.get_quantile_point(self.feat_lower)
  144. column_min_value = [column_min_value[key] for key in header]
  145. column_max_value = self.summary_obj.get_quantile_point(self.feat_upper)
  146. column_max_value = [column_max_value[key] for key in header]
  147. self.__check_equal(data_shape, len(column_min_value))
  148. self.__check_equal(data_shape, len(column_max_value))
  149. return column_min_value, column_max_value
  150. def _get_min_max_value(self, data):
  151. """
  152. Get each column minimum and maximum
  153. """
  154. if self.mode == consts.NORMAL:
  155. return self.__get_min_max_value_by_normal(data)
  156. elif self.mode == consts.CAP:
  157. return self.__get_min_max_value_by_cap(data)
  158. else:
  159. raise ValueError("unknown mode of {}".format(self.mode))
  160. def set_column_range(self, upper, lower):
  161. self.column_max_value = upper
  162. self.column_min_value = lower
  163. @staticmethod
  164. def reset_feature_range(data, column_max_value, column_min_value, scale_column_idx):
  165. _data = copy.deepcopy(data)
  166. for i in scale_column_idx:
  167. value = _data.features[i]
  168. if value > column_max_value[i]:
  169. _data.features[i] = column_max_value[i]
  170. elif value < column_min_value[i]:
  171. _data.features[i] = column_min_value[i]
  172. return _data
  173. def fit_feature_range(self, data):
  174. if self.feat_lower is not None or self.feat_upper is not None:
  175. LOGGER.info("Need fit feature range")
  176. if not isinstance(self.column_min_value, Iterable) or not isinstance(self.column_max_value, Iterable):
  177. LOGGER.info(
  178. "column_min_value type is:{}, column_min_value type is:{} , should be iterable, start to get new one".format(
  179. type(
  180. self.column_min_value), type(
  181. self.column_max_value)))
  182. self.column_min_value, self.column_max_value = self._get_min_max_value(data)
  183. if not self.scale_column_idx:
  184. self.scale_column_idx = self._get_scale_column_idx(data)
  185. LOGGER.info("scale_column_idx is None, start to get new one, new scale_column_idx:{}".format(
  186. self.scale_column_idx))
  187. f = functools.partial(self.reset_feature_range, column_max_value=self.column_max_value,
  188. column_min_value=self.column_min_value, scale_column_idx=self.scale_column_idx)
  189. fit_data = data.mapValues(f)
  190. fit_data.schema = data.schema
  191. return fit_data
  192. else:
  193. LOGGER.info("feat_lower is None and feat_upper is None, do not need to fit feature range!")
  194. return data
  195. def get_model_summary(self):
  196. cols_info = self._get_param().col_scale_param
  197. return {
  198. col_name: {
  199. "column_upper": col.column_upper,
  200. "column_lower": col.column_lower,
  201. "mean": col.mean,
  202. "std": col.std} for col_name,
  203. col in cols_info.items()}
  204. def export_model(self, need_run):
  205. meta_obj = self._get_meta(need_run)
  206. param_obj = self._get_param()
  207. result = {
  208. self.model_meta_name: meta_obj,
  209. self.model_param_name: param_obj
  210. }
  211. return result
  212. def fit(self, data):
  213. pass
  214. def transform(self, data):
  215. pass
  216. def load_model(self, name, namespace):
  217. pass
  218. def save_model(self, name, namespace):
  219. pass
  220. def _get_param(self):
  221. pass
  222. def _get_meta(self, need_run):
  223. pass