abnormal_detection.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import numpy as np
  18. import functools
  19. import copy
  20. from federatedml.statistic import data_overview
  21. from federatedml.util import LOGGER
  22. def empty_table_detection(data_instances):
  23. num_data = data_instances.count()
  24. if num_data == 0:
  25. raise ValueError(f"Count of data_instance is 0: {data_instances}")
  26. def empty_feature_detection(data_instances):
  27. is_empty_feature = data_overview.is_empty_feature(data_instances)
  28. if is_empty_feature:
  29. raise ValueError(f"Number of features of Table is 0: {data_instances}")
  30. def column_gathering(iterable, ):
  31. non_empty_columns = set()
  32. for k, v in iterable:
  33. features = v.features
  34. if isinstance(features.dtype, (np.int, np.int64, np.int32, np.float, np.float32, np.float64, np.long)):
  35. non_empty_columns.update(np.where(~np.isnan(features))[0])
  36. else:
  37. for col_idx, col_v in enumerate(features):
  38. if col_v != col_v or col_v == "":
  39. continue
  40. else:
  41. non_empty_columns.add(col_idx)
  42. return non_empty_columns
  43. def merge_column_sets(v1: set, v2: set):
  44. v1_copy = copy.deepcopy(v1)
  45. v2_copy = copy.deepcopy(v2)
  46. v1_copy.update(v2_copy)
  47. return v1_copy
  48. def empty_column_detection(data_instance):
  49. contains_empty_columns = False
  50. lost_feat = []
  51. is_sparse = data_overview.is_sparse_data(data_instance)
  52. if is_sparse:
  53. raise ValueError('sparse format empty column detection is not supported for now')
  54. map_func = functools.partial(column_gathering, )
  55. map_rs = data_instance.applyPartitions(map_func)
  56. reduce_rs = map_rs.reduce(merge_column_sets)
  57. # transform col index to col name
  58. reduce_rs = np.array(data_instance.schema['header'])[list(reduce_rs)]
  59. reduce_rs = set(reduce_rs)
  60. if reduce_rs != set(data_instance.schema['header']):
  61. lost_feat = list(set(data_instance.schema['header']).difference(reduce_rs))
  62. contains_empty_columns = True
  63. if contains_empty_columns:
  64. raise ValueError('column(s) {} contain(s) no values'.format(lost_feat))
  65. def check_legal_schema(schema):
  66. # check for repeated header & illegal/non-printable chars except for space
  67. # allow non-ascii chars
  68. LOGGER.debug(f"schema is {schema}")
  69. if schema is None:
  70. return
  71. header = schema.get("header", None)
  72. LOGGER.debug(f"header is {header}")
  73. if header is not None:
  74. for col_name in header:
  75. if not col_name.isprintable():
  76. raise ValueError(f"non-printable char found in header column {col_name}, please check.")
  77. header_set = set(header)
  78. if len(header_set) != len(header):
  79. raise ValueError(f"data header contains repeated names, please check.")
  80. sid_name = schema.get("sid", None)
  81. LOGGER.debug(f"sid is {sid_name}")
  82. if sid_name is not None and not sid_name.isprintable():
  83. raise ValueError(f"non-printable char found in sid_name {sid_name}, please check.")
  84. label_name = schema.get("label_name", None)
  85. LOGGER.debug(f"label_name is {label_name}")
  86. if label_name is not None and not label_name.isprintable():
  87. raise ValueError(f"non-printable char found in label_name {label_name}, please check.")