123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import numpy as np
- import functools
- import copy
- from federatedml.statistic import data_overview
- from federatedml.util import LOGGER
def empty_table_detection(data_instances):
    """Reject a table that holds no records.

    Args:
        data_instances: object exposing a ``count()`` method that returns the
            number of records it holds.

    Raises:
        ValueError: if ``data_instances.count()`` equals 0.
    """
    if data_instances.count() != 0:
        return
    raise ValueError(f"Count of data_instance is 0: {data_instances}")
def empty_feature_detection(data_instances):
    """Reject a table that ``data_overview.is_empty_feature`` flags as empty.

    Args:
        data_instances: table handed unchanged to
            ``data_overview.is_empty_feature``.

    Raises:
        ValueError: if ``data_overview.is_empty_feature`` returns a truthy
            value for ``data_instances``.
    """
    if not data_overview.is_empty_feature(data_instances):
        return
    raise ValueError(f"Number of features of Table is 0: {data_instances}")
def column_gathering(iterable):
    """Collect the indices of feature columns that hold at least one value.

    A cell counts as empty when it is NaN or, for non-numeric features, the
    empty string.

    Args:
        iterable: iterable of ``(key, instance)`` pairs; each instance must
            expose a ``features`` attribute that has a ``dtype`` (i.e. a
            NumPy array).

    Returns:
        set: positional indices (plain ``int``) of the columns that are
        non-empty in at least one instance.
    """
    non_empty_columns = set()
    for _, instance in iterable:
        features = instance.features
        # A dtype object is never an instance of scalar classes such as
        # np.int64, and the np.int / np.float aliases are gone from current
        # NumPy, so the dtype has to be classified with issubdtype.
        if np.issubdtype(features.dtype, np.number):
            # Vectorised path: every position that is not NaN is non-empty.
            filled = np.where(~np.isnan(features))[0]
            non_empty_columns.update(filled.tolist())
        else:
            for col_idx, col_v in enumerate(features):
                # `col_v != col_v` is true only for NaN-like values.
                if col_v != col_v or col_v == "":
                    continue
                non_empty_columns.add(col_idx)
    return non_empty_columns
def merge_column_sets(v1: set, v2: set):
    """Return the union of two column-index sets without touching the inputs.

    Both arguments are deep-copied before being combined, so neither ``v1``
    nor ``v2`` is mutated and the result shares no elements with them.

    Args:
        v1: first set of column indices.
        v2: second set of column indices.

    Returns:
        A deep copy of ``v1`` extended with a deep copy of ``v2``.
    """
    merged = copy.deepcopy(v1)
    merged.update(copy.deepcopy(v2))
    return merged
def empty_column_detection(data_instance):
    """Raise if any header column has no value in any instance.

    Each partition is scanned with ``column_gathering`` and the per-partition
    index sets are combined with ``merge_column_sets``; the resulting indices
    are mapped to names through ``data_instance.schema['header']``.

    Args:
        data_instance: table exposing ``applyPartitions`` (whose result
            exposes ``reduce``) and a ``schema`` mapping with a ``'header'``
            entry.

    Raises:
        ValueError: if ``data_overview.is_sparse_data`` reports sparse data
            (not supported), or if at least one header name is absent from
            the non-empty columns. The offending names are listed in header
            order.
    """
    if data_overview.is_sparse_data(data_instance):
        raise ValueError('sparse format empty column detection is not supported for now')

    non_empty_idx = data_instance.applyPartitions(column_gathering).reduce(merge_column_sets)
    if non_empty_idx is None:
        # NOTE(review): presumably reduce yields None when there is nothing to
        # reduce (empty table) - confirm against the Table API. Treating it as
        # "no non-empty column" reports every column instead of crashing.
        non_empty_idx = set()

    # transform col index to col name
    header = data_instance.schema['header']
    present = {header[idx] for idx in non_empty_idx}

    # Walk the header (de-duplicated, order kept) rather than taking a set
    # difference, so the error message lists columns deterministically.
    lost_feat = [name for name in dict.fromkeys(header) if name not in present]
    if lost_feat:
        raise ValueError('column(s) {} contain(s) no values'.format(lost_feat))
def check_legal_schema(schema):
    """Validate the names stored in a table schema.

    Checks performed, in order:
      1. every entry of ``schema["header"]`` is printable (spaces and
         non-ASCII characters are allowed),
      2. ``schema["header"]`` contains no repeated names,
      3. ``schema["sid"]`` is printable,
      4. ``schema["label_name"]`` is printable.
    Keys that are missing or map to ``None`` are skipped. A ``None`` schema
    is accepted without any check.

    Args:
        schema: mapping supporting ``get``, or ``None``.

    Raises:
        ValueError: on the first failed check.
    """
    LOGGER.debug(f"schema is {schema}")
    if schema is None:
        return

    header = schema.get("header", None)
    LOGGER.debug(f"header is {header}")
    if header is not None:
        # Find the first offending column lazily; a non-printable name is
        # always a non-empty string, so None is a safe "nothing found" marker.
        col_name = next((name for name in header if not name.isprintable()), None)
        if col_name is not None:
            raise ValueError(f"non-printable char found in header column {col_name}, please check.")
        if len(header) != len(set(header)):
            raise ValueError("data header contains repeated names, please check.")

    sid_name = schema.get("sid", None)
    LOGGER.debug(f"sid is {sid_name}")
    if not (sid_name is None or sid_name.isprintable()):
        raise ValueError(f"non-printable char found in sid_name {sid_name}, please check.")

    label_name = schema.get("label_name", None)
    LOGGER.debug(f"label_name is {label_name}")
    if not (label_name is None or label_name.isprintable()):
        raise ValueError(f"non-printable char found in label_name {label_name}, please check.")
|