| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- #
- # Copyright 2019 The Eggroll Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from collections import Iterable
- from fate_arch._standalone import Session
- from fate_arch.abc import AddressABC, CSessionABC
- from fate_arch.common.base_utils import fate_uuid
- from fate_arch.common.log import getLogger
- from fate_arch.computing.standalone._table import Table
# Module-level logger; CSession.destroy uses it to report cleanup/stop progress and failures.
LOGGER = getLogger()
class CSession(CSessionABC):
    """Standalone computing session wrapping ``fate_arch._standalone.Session``.

    Most operations delegate directly to the wrapped session; tables it
    returns are wrapped in the standalone computing ``Table``.
    """

    def __init__(self, session_id: str, options=None):
        """Create the underlying standalone session.

        Args:
            session_id: identifier forwarded to the standalone ``Session``.
            options: optional mapping; when given, ``options["task_cores"]``
                (default ``None`` if absent) is forwarded as ``max_workers``.
        """
        # Default to None so that ``options=None`` does not leave
        # ``max_workers`` unbound when constructing the Session below.
        max_workers = None
        if options is not None:
            max_workers = options.get("task_cores", None)
        self._session = Session(session_id, max_workers=max_workers)

    def get_standalone_session(self):
        """Return the wrapped ``fate_arch._standalone.Session`` instance."""
        return self._session

    @property
    def session_id(self):
        """Session id as reported by the wrapped standalone session."""
        return self._session.session_id

    def load(self, address: AddressABC, partitions: int, schema: dict, **kwargs):
        """Load a table (or local data) for the given storage address.

        Args:
            address: a ``StandaloneAddress`` or ``PathAddress``.
            partitions: partition count used when the standalone table is
                copied (see below).
            schema: schema assigned to the returned ``Table``.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            ``Table`` for a ``StandaloneAddress``; ``LocalData`` for a
            ``PathAddress``.

        Raises:
            NotImplementedError: for any other address type.
        """
        from fate_arch.common.address import StandaloneAddress
        from fate_arch.storage import StandaloneStoreType

        if isinstance(address, StandaloneAddress):
            raw_table = self._session.load(address.name, address.namespace)
            if address.storage_type != StandaloneStoreType.ROLLPAIR_IN_MEMORY:
                # Non-in-memory storage is copied into a uniquely named table
                # flagged need_cleanup=True -- presumably so computation works
                # on a disposable copy, not the persisted table (confirm).
                raw_table = raw_table.save_as(
                    name=f"{address.name}_{fate_uuid()}",
                    namespace=address.namespace,
                    partition=partitions,
                    need_cleanup=True,
                )
            table = Table(raw_table)
            table.schema = schema
            return table

        from fate_arch.common.address import PathAddress

        if isinstance(address, PathAddress):
            from fate_arch.computing.non_distributed import LocalData
            from fate_arch.computing import ComputingEngine

            return LocalData(address.path, engine=ComputingEngine.STANDALONE)

        raise NotImplementedError(
            f"address type {type(address)} not supported with standalone backend"
        )

    def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs):
        """Distribute ``data`` into a new table via the wrapped session.

        Extra keyword arguments are forwarded unchanged; the result is wrapped
        in a ``Table``.
        """
        table = self._session.parallelize(
            data=data, partition=partition, include_key=include_key, **kwargs
        )
        return Table(table)

    def cleanup(self, name, namespace):
        """Delegate table cleanup for ``name``/``namespace`` to the wrapped session."""
        return self._session.cleanup(name=name, namespace=namespace)

    def stop(self):
        """Stop the wrapped standalone session."""
        return self._session.stop()

    def kill(self):
        """Kill the wrapped standalone session."""
        return self._session.kill()

    def destroy(self):
        """Best-effort teardown: clean this session's tables, then stop it.

        Tables in the namespace named after ``session_id`` are cleaned first;
        a failure there is logged as a warning and does not abort teardown.
        If ``stop`` raises, the error is logged and ``kill`` is attempted as a
        fallback (exceptions from ``kill`` propagate to the caller).
        """
        try:
            LOGGER.info(f"clean table namespace {self.session_id}")
            # name="*" presumably selects every table in the namespace -- confirm
            self.cleanup(namespace=self.session_id, name="*")
        except Exception:
            LOGGER.warning(f"no found table namespace {self.session_id}")
        try:
            self.stop()
        except Exception:
            # Attach the exception via exc_info. Passing it as a positional
            # argument would be treated as a %-format arg for a message with
            # no placeholder, making the logging call itself fail to format.
            LOGGER.warning(
                f"stop storage session {self.session_id} failed, try to kill",
                exc_info=True,
            )
            self.kill()
|