#!/usr/bin/env python # -*- coding: utf-8 -*- # # Copyright 2019 The FATE Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import copy import functools import numpy as np DEFAULT_LABEL_NAME = "label" DEFAULT_MATCH_ID_PREFIX = "match_id" SVMLIGHT_COLUMN_PREFIX = "x" DEFAULT_SID_NAME = "sid" DELIMITER = "," class DataFormatPreProcess(object): @staticmethod def get_feature_offset(meta): """ works for sparse/svmlight/tag value data """ with_label = meta.get("with_label", False) with_match_id = meta.get("with_match_id", False) id_range = meta.get("id_range", 0) if with_match_id: if not id_range: id_range = 1 offset = id_range if with_label: offset += 1 return offset @staticmethod def agg_partition_tags(kvs, delimiter=",", offset=0, tag_with_value=True, tag_value_delimiter=":"): tag_set = set() for _, value in kvs: cols = value.split(delimiter, -1)[offset:] if tag_with_value: tag_set |= set([col.split(tag_value_delimiter, -1)[0] for col in cols]) else: tag_set |= set(cols) return tag_set @staticmethod def get_tag_list(data, schema): if "meta" not in schema: raise ValueError("Meta not in schema") meta = schema["meta"] if meta["input_format"] != "tag": raise ValueError("Input DataFormat Should Be Tag Or Tag Value") delimiter = meta["delimiter"] tag_with_value = meta["tag_with_value"] if not isinstance(tag_with_value, bool): raise ValueError(f"tag with value should be bool, bug {tag_with_value} find") tag_value_delimiter = meta["tag_value_delimiter"] offset = DataFormatPreProcess.get_feature_offset(meta) agg_func = functools.partial(DataFormatPreProcess.agg_partition_tags, delimiter=delimiter, offset=offset, tag_with_value=tag_with_value, tag_value_delimiter=tag_value_delimiter) agg_tags = data.applyPartitions(agg_func).reduce(lambda tag_set1, tag_set2: tag_set1 | tag_set2) return sorted(agg_tags) @staticmethod def get_lib_svm_dim(data, schema): if "meta" not in schema: raise ValueError("Meta not in schema") meta = schema["meta"] if "input_format" == ["sparse", "svmlight"]: raise ValueError("Input DataFormat Should Be SVMLight") delimiter = meta.get("delimiter", " ") offset = DataFormatPreProcess.get_feature_offset(meta) max_dim = data.\ mapValues( lambda value: max([int(fid_value.split(":", -1)[0]) for fid_value in value.split(delimiter, -1)[offset:]])).\ reduce(lambda x, y: max(x, y)) return max_dim @staticmethod def generate_header(data, schema): if not schema.get('meta'): raise ValueError("Meta not in schema") meta = schema["meta"] generated_header = dict(original_index_info=dict(), meta=meta) input_format = meta.get("input_format") delimiter = meta.get("delimiter", ",") if not input_format: raise ValueError("InputFormat should be configured.") if input_format == "dense": if "header" not in schema: raise ValueError("Dense input data must have schema") header = schema["header"].strip().split(delimiter, -1) header = list(map(lambda col: col.strip(), header)) header_index_mapping = dict(zip(header, range(len(header)))) with_label = meta.get("with_label", False) if not isinstance(with_label, bool): raise ValueError("with_label should be True or False") id_list = meta.get("id_list", []) if not isinstance(id_list, (type(None), list)): raise ValueError("id_list should be list type or None") with_match_id = meta.get("with_match_id", False) filter_ids = set() if with_match_id: if not id_list: match_id_name = header[0] match_id_index = [0] filter_ids.add(0) else: match_id_name = [] match_id_index = [] for _id in id_list: if _id in header_index_mapping: match_id_name.append(_id) match_id_index.append(header_index_mapping[_id]) filter_ids.add(match_id_index[-1]) else: raise ValueError(f"Can not find {_id} in id_list in data's header") generated_header["match_id_name"] = match_id_name generated_header["original_index_info"]["match_id_index"] = match_id_index if with_label: label_name = meta["label_name"] label_index = header_index_mapping[label_name] generated_header["label_name"] = label_name generated_header["original_index_info"]["label_index"] = label_index filter_ids.add(label_index) header_ids = list(filter(lambda ids: ids not in filter_ids, range(len(header)))) generated_header["original_index_info"]["header_index"] = header_ids generated_header["header"] = np.array(header)[header_ids].tolist() else: if input_format == "tag": sorted_tag_list = DataFormatPreProcess.get_tag_list(data, schema) generated_header["header"] = sorted_tag_list elif input_format in ["sparse", "svmlight"]: max_dim = DataFormatPreProcess.get_lib_svm_dim(data, schema) generated_header["header"] = [SVMLIGHT_COLUMN_PREFIX + str(i) for i in range(max_dim + 1)] else: raise NotImplementedError(f"InputFormat {input_format} is not implemented") with_label = meta.get("with_label", False) with_match_id = meta.get("with_match_id", False) id_range = meta.get("id_range", 0) if id_range and not with_match_id: raise ValueError(f"id_range {id_range} != 0, with_match_id should be true") if with_match_id: if not id_range: id_range = 1 if id_range == 1: generated_header["match_id_name"] = DEFAULT_MATCH_ID_PREFIX else: generated_header["match_id_name"] = [DEFAULT_MATCH_ID_PREFIX + str(i) for i in range(id_range)] if with_label: generated_header["label_name"] = DEFAULT_LABEL_NAME if id_range: generated_header["meta"]["id_range"] = id_range generated_header["is_display"] = False sid = schema.get("sid") if sid is None or sid == "": sid = DEFAULT_SID_NAME generated_header["sid"] = sid.strip() return generated_header @staticmethod def reconstruct_header(schema): original_index_info = schema.get("original_index_info") if not original_index_info: return schema["header"] header_index_mapping = dict() if "header_index" in original_index_info and original_index_info["header_index"]: for idx, col_name in zip(original_index_info["header_index"], schema["header"]): header_index_mapping[idx] = col_name if original_index_info.get("match_id_index") is not None: match_id_name = schema["match_id_name"] match_id_index = original_index_info["match_id_index"] if isinstance(match_id_name, str): header_index_mapping[match_id_index[0]] = match_id_name else: for idx, col_name in zip(match_id_index, match_id_name): header_index_mapping[idx] = col_name if original_index_info.get("label_index") is not None: header_index_mapping[original_index_info["label_index"]] = schema["label_name"] original_header = [None] * len(header_index_mapping) for idx, col_name in header_index_mapping.items(): original_header[idx] = col_name return original_header @staticmethod def extend_header(schema, columns): schema = copy.deepcopy(schema) original_index_info = schema.get("original_index_info") columns = list(map(lambda column: column.strip(), columns)) header = schema["header"] if isinstance(header, list): header.extend(columns) schema["header"] = header if original_index_info and "header_index" in original_index_info: header_index = original_index_info["header_index"] if header_index: pre_max_col_idx = max(header_index) else: pre_max_col_idx = -1 if original_index_info.get("label_index") is not None: pre_max_col_idx = max(original_index_info["label_index"], pre_max_col_idx) if original_index_info.get("match_id_index") is not None: pre_max_col_idx = max(max(original_index_info["match_id_index"]), pre_max_col_idx) append_header_index = [i + pre_max_col_idx + 1 for i in range(len(columns))] schema["original_index_info"]["header_index"] = header_index + append_header_index else: if len(header) == 0: new_header = DELIMITER.join(columns) else: new_header = DELIMITER.join(header.split(DELIMITER, -1) + columns) schema["header"] = new_header if schema.get("sid") is not None: schema["sid"] = schema["sid"].strip() return schema @staticmethod def clean_header(schema): schema = copy.deepcopy(schema) header = schema["header"] if "label_name" in schema: del schema["label_name"] if "anonymous_header" in schema: del schema["anonymous_header"] if "anonymous_label" in schema: del schema["anonymous_label"] if isinstance(header, list): schema["header"] = [] original_index_info = schema.get("original_index_info") if original_index_info: del schema["original_index_info"] if "match_id_name" in schema: del schema["match_id_name"] if "match_id_index" in schema: del schema["match_id_index"] else: schema["header"] = "" return schema @staticmethod def recover_schema(schema): if not schema.get('meta'): raise ValueError("Meta not in schema, can not recover meta") recovery_schema = copy.deepcopy(schema) meta = schema["meta"] input_format = meta.get("input_format", "dense") if input_format == "dense": """schema has not been processed yet""" if "original_index_info" not in schema: return recovery_schema header_list = DataFormatPreProcess.reconstruct_header(schema) del recovery_schema["original_index_info"] delimiter = schema.get("delimiter", ",") header = "" if not header_list else delimiter.join(header_list) recovery_schema["header"] = header if "label_name" in recovery_schema: del recovery_schema["label_name"] if meta.get("with_match_id"): del recovery_schema["match_id_name"] else: recovery_schema["header"] = "" if "label_name" in recovery_schema: del recovery_schema["label_name"] if meta.get("id_range"): recovery_schema["meta"]["id_range"] = 0 if meta.get("with_label"): del recovery_schema["meta"]["label_name"] del recovery_schema["is_display"] if meta.get("with_match_id"): del recovery_schema["match_id_name"] if "anonymous_header" in schema: del recovery_schema["anonymous_header"] if "anonymous_label" in schema: del recovery_schema["anonymous_label"] return recovery_schema