123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import typing
- from uuid import uuid1
- from fate_arch import session, storage
- from fate_arch.abc import CTableABC
- from fate_arch.common import DTable
- from fate_arch.common.base_utils import current_timestamp
- from fate_flow.db.db_models import DB, CacheRecord
- from fate_flow.entity import DataCache
- class CacheManager:
- @classmethod
- def persistent(cls, cache_name: str, cache_data: typing.Dict[str, CTableABC], cache_meta: dict, output_namespace: str,
- output_name: str, output_storage_engine: str, output_storage_address: dict,
- token=None) -> DataCache:
- cache = DataCache(name=cache_name, meta=cache_meta)
- for name, table in cache_data.items():
- table_meta = session.Session.persistent(computing_table=table,
- namespace=output_namespace,
- name=f"{output_name}_{name}",
- schema=None,
- engine=output_storage_engine,
- engine_address=output_storage_address,
- token=token)
- cache.data[name] = DTable(namespace=table_meta.namespace, name=table_meta.name,
- partitions=table_meta.partitions)
- return cache
- @classmethod
- def load(cls, cache: DataCache) -> typing.Tuple[typing.Dict[str, CTableABC], dict]:
- cache_data = {}
- for name, table in cache.data.items():
- storage_table_meta = storage.StorageTableMeta(name=table.name, namespace=table.namespace)
- computing_table = session.get_computing_session().load(
- storage_table_meta.get_address(),
- schema=storage_table_meta.get_schema(),
- partitions=table.partitions)
- cache_data[name] = computing_table
- return cache_data, cache.meta
- @classmethod
- @DB.connection_context()
- def record(cls, cache: DataCache, job_id: str = None, role: str = None, party_id: int = None, component_name: str = None, task_id: str = None, task_version: int = None,
- cache_name: str = None):
- for attr in {"job_id", "component_name", "task_id", "task_version"}:
- if getattr(cache, attr) is None and locals().get(attr) is not None:
- setattr(cache, attr, locals().get(attr))
- record = CacheRecord()
- record.f_create_time = current_timestamp()
- record.f_cache_key = uuid1().hex
- cache.key = record.f_cache_key
- record.f_cache = cache
- record.f_job_id = job_id
- record.f_role = role
- record.f_party_id = party_id
- record.f_component_name = component_name
- record.f_task_id = task_id
- record.f_task_version = task_version
- record.f_cache_name = cache_name
- rows = record.save(force_insert=True)
- if rows != 1:
- raise Exception("save cache tracking failed")
- return record.f_cache_key
- @classmethod
- @DB.connection_context()
- def query(cls, cache_key: str = None, role: str = None, party_id: int = None, component_name: str = None, cache_name: str = None,
- **kwargs) -> typing.List[DataCache]:
- if cache_key is not None:
- records = CacheRecord.query(cache_key=cache_key)
- else:
- records = CacheRecord.query(role=role, party_id=party_id, component_name=component_name,
- cache_name=cache_name, **kwargs)
- return [record.f_cache for record in records]
- @classmethod
- @DB.connection_context()
- def query_record(cls, role: str = None, party_id: int = None, component_name: str = None, **kwargs) -> typing.List[CacheRecord]:
- records = CacheRecord.query(role=role, party_id=party_id, component_name=component_name, **kwargs)
- return [record for record in records]
|