data_statistics.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import numpy as np
  18. from federatedml.feature.fate_element_type import NoneType
  19. from federatedml.model_base import ModelBase
  20. from federatedml.param.statistics_param import StatisticsParam
  21. from federatedml.protobuf.generated import statistic_meta_pb2, statistic_param_pb2
  22. from federatedml.statistic.data_overview import get_header
  23. from federatedml.statistic.statics import MultivariateStatisticalSummary
  24. from federatedml.util import LOGGER
  25. from federatedml.util import abnormal_detection
  26. from federatedml.util import consts
  27. MODEL_PARAM_NAME = 'StatisticParam'
  28. MODEL_META_NAME = 'StatisticMeta'
  29. SYSTEM_ABNORMAL_VALUES = [None, np.nan, NoneType]
  30. class StatisticInnerParam(object):
  31. def __init__(self):
  32. self.col_name_maps = {}
  33. self.header = []
  34. self.static_indices = []
  35. self.static_indices_set = set()
  36. self.static_names = []
  37. def set_header(self, header):
  38. self.header = header
  39. for idx, col_name in enumerate(self.header):
  40. self.col_name_maps[col_name] = idx
  41. def set_static_all(self):
  42. self.static_indices = [i for i in range(len(self.header))]
  43. self.static_indices_set = set(self.static_indices)
  44. self.static_names = self.header
  45. def add_static_indices(self, static_indices):
  46. if static_indices is None:
  47. return
  48. for idx in static_indices:
  49. if idx >= len(self.header):
  50. LOGGER.warning("Adding indices that out of header's bound")
  51. continue
  52. if idx not in self.static_indices_set:
  53. self.static_indices_set.add(idx)
  54. self.static_indices.append(idx)
  55. self.static_names.append(self.header[idx])
  56. def add_static_names(self, static_names):
  57. if static_names is None:
  58. return
  59. for col_name in static_names:
  60. idx = self.col_name_maps.get(col_name)
  61. if idx is None:
  62. LOGGER.warning(f"Adding col_name: {col_name} that is not exist in header")
  63. continue
  64. if idx not in self.static_indices_set:
  65. self.static_indices_set.add(idx)
  66. self.static_indices.append(idx)
  67. self.static_names.append(self.header[idx])
  68. class DataStatistics(ModelBase):
  69. def __init__(self):
  70. super().__init__()
  71. self.model_param = StatisticsParam()
  72. self.inner_param = None
  73. self.schema = None
  74. self.statistic_obj: MultivariateStatisticalSummary = None
  75. self._result_dict = {}
  76. self._numeric_statics = []
  77. self._quantile_statics = []
  78. self.feature_value_pb = []
  79. def _init_model(self, model_param):
  80. self.model_param = model_param
  81. for stat_name in self.model_param.statistics:
  82. if stat_name in self.model_param.LEGAL_STAT:
  83. self._numeric_statics.append(stat_name)
  84. else:
  85. self._quantile_statics.append(stat_name)
  86. def _init_param(self, data_instances):
  87. if self.schema is None or len(self.schema) == 0:
  88. self.schema = data_instances.schema
  89. if self.inner_param is not None:
  90. return
  91. self.inner_param = StatisticInnerParam()
  92. # self.schema = data_instances.schema
  93. LOGGER.debug("In _init_params, schema is : {}".format(self.schema))
  94. header = get_header(data_instances)
  95. self.inner_param.set_header(header)
  96. if self.model_param.column_indexes == -1:
  97. self.inner_param.set_static_all()
  98. else:
  99. self.inner_param.add_static_indices(self.model_param.column_indexes)
  100. self.inner_param.add_static_names(self.model_param.column_names)
  101. LOGGER.debug(f"column_indexes: {self.model_param.column_indexes}, inner_param"
  102. f" static_indices: {self.inner_param.static_indices}")
  103. return self
  104. @staticmethod
  105. def _merge_abnormal_list(abnormal_list):
  106. if abnormal_list is None:
  107. return SYSTEM_ABNORMAL_VALUES
  108. return abnormal_list + SYSTEM_ABNORMAL_VALUES
  109. def fit(self, data_instances):
  110. self._init_param(data_instances)
  111. self._abnormal_detection(data_instances)
  112. if consts.KURTOSIS in self.model_param.statistics:
  113. stat_order = 4
  114. elif consts.SKEWNESS in self.model_param.statistics:
  115. stat_order = 3
  116. else:
  117. stat_order = 2
  118. abnormal_list = self._merge_abnormal_list(self.model_param.abnormal_list)
  119. self.statistic_obj = MultivariateStatisticalSummary(data_instances,
  120. cols_index=self.inner_param.static_indices,
  121. abnormal_list=abnormal_list,
  122. error=self.model_param.quantile_error,
  123. stat_order=stat_order,
  124. bias=self.model_param.bias)
  125. results = None
  126. for stat_name in self._numeric_statics:
  127. stat_res = self.statistic_obj.get_statics(stat_name)
  128. LOGGER.debug(f"state_name: {stat_name}, stat_res: {stat_res}")
  129. self.feature_value_pb.append(self._convert_pb(stat_res, stat_name))
  130. if results is None:
  131. results = {k: {stat_name: v} for k, v in stat_res.items()}
  132. else:
  133. for k, v in results.items():
  134. results[k] = dict(**v, **{stat_name: stat_res[k]})
  135. for query_point in self._quantile_statics:
  136. q = float(query_point[:-1]) / 100
  137. res = self.statistic_obj.get_quantile_point(q)
  138. self.feature_value_pb.append(self._convert_pb(res, query_point))
  139. if results is None:
  140. results = res
  141. else:
  142. for k, v in res.items():
  143. results[k][query_point] = v
  144. for k, v in results.items():
  145. # new_dict = {}
  146. # for stat_name, value in v.items():
  147. # LOGGER.debug(f"stat_name: {stat_name}, value: {value}, type: {type(value)}")
  148. self.add_summary(k, v)
  149. LOGGER.debug(f"Before return, summary: {self.summary()}")
  150. def _convert_pb(self, stat_res, stat_name):
  151. values = [stat_res[col_name] for col_name in self.inner_param.static_names]
  152. return statistic_param_pb2.StatisticSingleFeatureValue(
  153. values=values,
  154. col_names=self.inner_param.static_names,
  155. value_name=stat_name
  156. )
  157. def export_model(self):
  158. if self.model_output is not None:
  159. return self.model_output
  160. meta_obj = self._get_meta()
  161. param_obj = self._get_param()
  162. result = {
  163. MODEL_META_NAME: meta_obj,
  164. MODEL_PARAM_NAME: param_obj
  165. }
  166. self.model_output = result
  167. return result
  168. def _get_meta(self):
  169. return statistic_meta_pb2.StatisticMeta(
  170. statistics=self.model_param.statistics,
  171. static_columns=self.inner_param.static_names,
  172. quantile_error=self.model_param.quantile_error,
  173. need_run=self.model_param.need_run
  174. )
  175. def _get_param(self):
  176. all_result = statistic_param_pb2.StatisticOnePartyResult(
  177. results=self.feature_value_pb
  178. )
  179. return statistic_param_pb2.ModelParam(
  180. self_values=all_result,
  181. model_name=consts.STATISTIC_MODEL
  182. )
  183. def _abnormal_detection(self, data_instances):
  184. """
  185. Make sure input data_instances is valid.
  186. """
  187. abnormal_detection.empty_table_detection(data_instances)
  188. abnormal_detection.empty_feature_detection(data_instances)
  189. self.check_schema_content(data_instances.schema)