123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from fate_arch.abc import StorageTableMetaABC, AddressABC
- from fate_arch.common.address import MysqlAddress, HiveAddress
- from fate_arch.common.data_utils import default_output_fs_path
- from fate_arch.computing import ComputingEngine
- from fate_arch.storage import StorageEngine, StorageTableMeta
- from fate_flow.entity.types import InputSearchType
- from fate_arch import storage
def get_header_schema(header_line, id_delimiter, extend_sid=False):
    """Split a header line into its feature header and sample-id column name.

    :param header_line: raw header text whose columns are separated by
        ``id_delimiter``.
    :param id_delimiter: column separator used both to split the line and to
        re-join the remaining columns.
    :param extend_sid: when truthy every column is kept as a feature column and
        the id name comes from ``get_extend_id_name()``; otherwise the first
        column is taken as the id name and removed from the header.
    :return: dict with the keys ``'header'`` (re-joined, stripped column names)
        and ``'sid'`` (the id column name).
    """
    columns = header_line.split(id_delimiter)
    if not extend_sid:
        # The first column names the sample id; the rest are feature columns.
        return {
            'header': id_delimiter.join(columns[1:]).strip(),
            'sid': columns[0].strip(),
        }
    # No id column in the source header: keep all columns as features.
    return {
        'header': id_delimiter.join(columns).strip(),
        'sid': get_extend_id_name(),
    }
def get_extend_id_name():
    """Return the column name used for a generated (extended) sample id."""
    extend_id_name = "extend_sid"
    return extend_id_name
def get_sid_data_line(values, id_delimiter, fate_uuid, line_index, **kwargs):
    """Build a ``(sid, data)`` pair whose id is generated from a uuid and line index.

    :param values: column values of one data line; all of them go into the data part.
    :param id_delimiter: separator used to join ``values``.
    :param fate_uuid: prefix of the generated id.
    :param line_index: index of the line, appended to ``fate_uuid``.
    :param kwargs: accepted and ignored, so this function can be called with
        the same keyword set as the other ``*_data_line`` builders.
    :return: tuple ``(generated_id, joined_values)``.
    """
    generated_id = line_extend_uuid(fate_uuid, line_index)
    joined_values = list_to_str(values, id_delimiter=id_delimiter)
    return generated_id, joined_values
def line_extend_uuid(fate_uuid, line_index):
    """Return ``fate_uuid`` with the string form of ``line_index`` appended."""
    index_suffix = str(line_index)
    return fate_uuid + index_suffix
def get_auto_increasing_sid_data_line(values, id_delimiter, line_index, **kwargs):
    """Build a ``(sid, data)`` pair that uses the line index itself as the id.

    :param values: column values of one data line; all of them go into the data part.
    :param id_delimiter: separator used to join ``values``.
    :param line_index: index of the line, returned unchanged as the id.
    :param kwargs: accepted and ignored.
    :return: tuple ``(line_index, joined_values)``.
    """
    joined_values = list_to_str(values, id_delimiter=id_delimiter)
    return line_index, joined_values
def get_data_line(values, id_delimiter, **kwargs):
    """Build a ``(sid, data)`` pair from a line whose first value is the id.

    :param values: column values of one data line; ``values[0]`` is the id and
        the remaining values form the data part.
    :param id_delimiter: separator used to join the remaining values.
    :param kwargs: accepted and ignored.
    :return: tuple ``(values[0], joined_remaining_values)``.
    """
    sid = values[0]
    joined_features = list_to_str(values[1:], id_delimiter=id_delimiter)
    return sid, joined_features
def list_to_str(input_list, id_delimiter):
    """Join the string form of every item in ``input_list`` with ``id_delimiter``.

    :param input_list: iterable of values; each one is converted with ``str``.
    :param id_delimiter: separator string placed between the converted items.
    :return: the joined string (empty when ``input_list`` is empty).
    """
    return id_delimiter.join(str(item) for item in input_list)
def convert_output(
    input_name,
    input_namespace,
    output_name,
    output_namespace,
    computing_engine: ComputingEngine = ComputingEngine.EGGROLL,
    output_storage_address=None,
) -> (StorageTableMetaABC, AddressABC, StorageEngine):
    """Work out the address and storage engine of the output table for an input table.

    The input table's storage engine is checked first: a ``PATH`` input always
    produces a ``PATH`` output. Otherwise the choice is driven by
    ``computing_engine``:

    * ``STANDALONE`` -> ``STANDALONE`` storage with ``ROLLPAIR_LMDB`` store type
    * ``EGGROLL`` -> ``EGGROLL`` storage with ``ROLLPAIR_LMDB`` store type
    * ``SPARK`` -> the input address is reused for ``HIVE`` (name replaced) and
      ``LOCALFS`` (path replaced) inputs; any other input yields an ``HDFS``
      address whose path comes from ``default_output_fs_path``
    * ``LINKIS_SPARK`` -> the input address is reused with its name replaced

    :param input_name: name of the input table.
    :param input_namespace: namespace of the input table.
    :param output_name: name of the output table.
    :param output_namespace: namespace of the output table.
    :param computing_engine: computing engine the job runs on.
    :param output_storage_address: optional mapping of extra address fields
        (for example ``path_prefix``); it is copied, never modified. ``None``
        (the default) is treated as an empty mapping.
    :return: tuple ``(input_table_meta, output_table_address, output_table_engine)``.
    :raises RuntimeError: if the input table meta is falsy or the computing
        engine is not one of the supported ones.
    """
    input_table_meta = StorageTableMeta(name=input_name, namespace=input_namespace)
    if not input_table_meta:
        raise RuntimeError(
            f"can not found table name: {input_name} namespace: {input_namespace}"
        )
    # Always work on a private copy so neither the caller's mapping nor a
    # shared default object is touched by the assignments below.
    if output_storage_address is None:
        address_dict = {}
    else:
        address_dict = output_storage_address.copy()
    if input_table_meta.get_engine() in [StorageEngine.PATH]:
        from fate_arch.storage import PathStoreType
        address_dict["name"] = output_name
        address_dict["namespace"] = output_namespace
        address_dict["storage_type"] = PathStoreType.PICTURE
        # The output points at the same path as the input table.
        address_dict["path"] = input_table_meta.get_address().path
        output_table_address = StorageTableMeta.create_address(
            storage_engine=StorageEngine.PATH, address_dict=address_dict
        )
        output_table_engine = StorageEngine.PATH
    elif computing_engine == ComputingEngine.STANDALONE:
        from fate_arch.storage import StandaloneStoreType
        address_dict["name"] = output_name
        address_dict["namespace"] = output_namespace
        address_dict["storage_type"] = StandaloneStoreType.ROLLPAIR_LMDB
        output_table_address = StorageTableMeta.create_address(
            storage_engine=StorageEngine.STANDALONE, address_dict=address_dict
        )
        output_table_engine = StorageEngine.STANDALONE
    elif computing_engine == ComputingEngine.EGGROLL:
        from fate_arch.storage import EggRollStoreType
        address_dict["name"] = output_name
        address_dict["namespace"] = output_namespace
        address_dict["storage_type"] = EggRollStoreType.ROLLPAIR_LMDB
        output_table_address = StorageTableMeta.create_address(
            storage_engine=StorageEngine.EGGROLL, address_dict=address_dict
        )
        output_table_engine = StorageEngine.EGGROLL
    elif computing_engine == ComputingEngine.SPARK:
        if input_table_meta.get_engine() == StorageEngine.HIVE:
            # NOTE(review): the address object returned by get_address() is
            # modified directly; if that is the meta's own instance, the input
            # table's address changes as well - confirm this is intended.
            output_table_address = input_table_meta.get_address()
            output_table_address.name = output_name
            output_table_engine = input_table_meta.get_engine()
        elif input_table_meta.get_engine() == StorageEngine.LOCALFS:
            # Same reuse of the input address as above, with the path replaced.
            output_table_address = input_table_meta.get_address()
            output_table_address.path = default_output_fs_path(
                name=output_name,
                namespace=output_namespace,
                storage_engine=StorageEngine.LOCALFS
            )
            output_table_engine = input_table_meta.get_engine()
        else:
            # Every other input engine under Spark is written to HDFS.
            address_dict["path"] = default_output_fs_path(
                name=output_name,
                namespace=output_namespace,
                prefix=address_dict.get("path_prefix"),
                storage_engine=StorageEngine.HDFS
            )
            output_table_address = StorageTableMeta.create_address(
                storage_engine=StorageEngine.HDFS, address_dict=address_dict
            )
            output_table_engine = StorageEngine.HDFS
    elif computing_engine == ComputingEngine.LINKIS_SPARK:
        # Reuses the input address object with only its name replaced.
        output_table_address = input_table_meta.get_address()
        output_table_address.name = output_name
        output_table_engine = input_table_meta.get_engine()
    else:
        raise RuntimeError(f"can not support computing engine {computing_engine}")
    return input_table_meta, output_table_address, output_table_engine
def get_input_data_min_partitions(input_data, role, party_id):
    """Return the smallest partition count among a party's input tables.

    :param input_data: nested mapping ``role -> party_id -> {data_type: location}``
        where each location is a ``"<namespace>.<name>"`` string.
    :param role: role whose inputs are inspected; ``'arbiter'`` is skipped entirely.
    :param party_id: party id used to select the inputs of ``role``.
    :return: the smallest partition count found, or ``None`` when the role is
        ``'arbiter'``, there are no inputs, or no table meta was found. A falsy
        running minimum (``None`` or ``0``) is replaced by the next table's count.
    """
    if role == 'arbiter':
        return None
    smallest = None
    for data_location in input_data[role][party_id].values():
        location_parts = data_location.split('.')
        # Locations are "<namespace>.<name>": the name is the second part.
        table_name = location_parts[1]
        table_namespace = location_parts[0]
        table_meta = storage.StorageTableMeta(name=table_name, namespace=table_namespace)
        if not table_meta:
            continue
        partitions = table_meta.get_partitions()
        if not smallest or smallest > partitions:
            smallest = partitions
    return smallest
def get_input_search_type(parameters):
    """Classify how an input is referenced, based on the keys in ``parameters``.

    :param parameters: container supporting ``in`` tests (typically a dict).
    :return: ``InputSearchType.TABLE_INFO`` when both ``name`` and ``namespace``
        are present; otherwise ``InputSearchType.JOB_COMPONENT_OUTPUT`` when
        ``job_id``, ``component_name`` and ``data_name`` are all present;
        otherwise ``InputSearchType.UNKNOWN``.
    """
    table_keys = ("name", "namespace")
    job_output_keys = ("job_id", "component_name", "data_name")
    # Table info takes precedence when both key sets are present.
    if all(key in parameters for key in table_keys):
        return InputSearchType.TABLE_INFO
    if all(key in parameters for key in job_output_keys):
        return InputSearchType.JOB_COMPONENT_OUTPUT
    return InputSearchType.UNKNOWN
def address_filter(address):
    """Blank out the password field of an address and return its attribute dict.

    ``MysqlAddress`` instances get ``passwd`` set to ``None`` and ``HiveAddress``
    instances get ``password`` set to ``None``. The address is modified in
    place, and the returned mapping is the instance's own ``__dict__``, not a copy.

    :param address: address object to filter.
    :return: ``address.__dict__`` after the password field has been cleared.
    """
    secret_fields = (
        (MysqlAddress, "passwd"),
        (HiveAddress, "password"),
    )
    for address_type, field_name in secret_fields:
        if isinstance(address, address_type):
            setattr(address, field_name, None)
    return address.__dict__
|