#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Iterable

from fate_arch.abc import AddressABC
from fate_arch.abc import CSessionABC
from fate_arch.common.address import LocalFSAddress
from fate_arch.computing.spark._table import from_hdfs, from_rdd, from_hive, from_localfs
from fate_arch.common import log

LOGGER = log.getLogger()


class CSession(CSessionABC):
    """
    manage RDDTable
    """

    def __init__(self, session_id):
        self._session_id = session_id
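
    # `load` dispatches on the concrete address type (HDFS, plain path,
    # Hive/Linkis Hive, local filesystem) and returns the matching table object.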
    def load(self, address: AddressABC, partitions, schema, **kwargs):
        from fate_arch.common.address import HDFSAddress

        if isinstance(address, HDFSAddress):
            table = from_hdfs(
                paths=f"{address.name_node}/{address.path}",
                partitions=partitions,
                in_serialized=kwargs.get("in_serialized", True),
                id_delimiter=kwargs.get("id_delimiter", ','),
            )
            table.schema = schema
            return table
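
        # A bare path address points at non-distributed local data and is
        # returned as LocalData instead of a Spark-backed table.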
        from fate_arch.common.address import PathAddress

        if isinstance(address, PathAddress):
            from fate_arch.computing.non_distributed import LocalData
            from fate_arch.computing import ComputingEngine

            return LocalData(address.path, engine=ComputingEngine.SPARK)

        from fate_arch.common.address import HiveAddress, LinkisHiveAddress

        if isinstance(address, (HiveAddress, LinkisHiveAddress)):
            table = from_hive(
                tb_name=address.name,
                db_name=address.database,
                partitions=partitions,
            )
            table.schema = schema
            return table

        if isinstance(address, LocalFSAddress):
            table = from_localfs(
                paths=address.path,
                partitions=partitions,
                in_serialized=kwargs.get("in_serialized", True),
                id_delimiter=kwargs.get("id_delimiter", ','),
            )
            table.schema = schema
            return table

        raise NotImplementedError(
            f"address type {type(address)} not supported with spark backend"
        )
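
    # `parallelize` builds an RDD-backed table from a local iterable; when
    # `include_key` is False, items are keyed by their enumeration index.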
    def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs):
        # noinspection PyPackageRequirements
        from pyspark import SparkContext

        _iter = data if include_key else enumerate(data)
        rdd = SparkContext.getOrCreate().parallelize(_iter, partition)
        return from_rdd(rdd)

    @property
    def session_id(self):
        return self._session_id
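
    # The lifecycle hooks below are no-ops for this backend; the underlying
    # Spark context appears to be managed outside of this class.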
    def cleanup(self, name, namespace):
        pass

    def stop(self):
        pass

    def kill(self):
        pass

    def destroy(self):
        pass
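

if __name__ == "__main__":
    # A minimal usage sketch: it assumes a local pyspark installation and that
    # the table returned by `from_rdd` exposes `collect()` yielding
    # (key, value) pairs.
    demo_session = CSession(session_id="example_session")
    demo_table = demo_session.parallelize(
        [("k1", 1), ("k2", 2)], partition=2, include_key=True
    )
    print(list(demo_table.collect()))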
|