cache_manager.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import typing
  17. from uuid import uuid1
  18. from fate_arch import session, storage
  19. from fate_arch.abc import CTableABC
  20. from fate_arch.common import DTable
  21. from fate_arch.common.base_utils import current_timestamp
  22. from fate_flow.db.db_models import DB, CacheRecord
  23. from fate_flow.entity import DataCache
  24. class CacheManager:
  25. @classmethod
  26. def persistent(cls, cache_name: str, cache_data: typing.Dict[str, CTableABC], cache_meta: dict, output_namespace: str,
  27. output_name: str, output_storage_engine: str, output_storage_address: dict,
  28. token=None) -> DataCache:
  29. cache = DataCache(name=cache_name, meta=cache_meta)
  30. for name, table in cache_data.items():
  31. table_meta = session.Session.persistent(computing_table=table,
  32. namespace=output_namespace,
  33. name=f"{output_name}_{name}",
  34. schema=None,
  35. engine=output_storage_engine,
  36. engine_address=output_storage_address,
  37. token=token)
  38. cache.data[name] = DTable(namespace=table_meta.namespace, name=table_meta.name,
  39. partitions=table_meta.partitions)
  40. return cache
  41. @classmethod
  42. def load(cls, cache: DataCache) -> typing.Tuple[typing.Dict[str, CTableABC], dict]:
  43. cache_data = {}
  44. for name, table in cache.data.items():
  45. storage_table_meta = storage.StorageTableMeta(name=table.name, namespace=table.namespace)
  46. computing_table = session.get_computing_session().load(
  47. storage_table_meta.get_address(),
  48. schema=storage_table_meta.get_schema(),
  49. partitions=table.partitions)
  50. cache_data[name] = computing_table
  51. return cache_data, cache.meta
  52. @classmethod
  53. @DB.connection_context()
  54. def record(cls, cache: DataCache, job_id: str = None, role: str = None, party_id: int = None, component_name: str = None, task_id: str = None, task_version: int = None,
  55. cache_name: str = None):
  56. for attr in {"job_id", "component_name", "task_id", "task_version"}:
  57. if getattr(cache, attr) is None and locals().get(attr) is not None:
  58. setattr(cache, attr, locals().get(attr))
  59. record = CacheRecord()
  60. record.f_create_time = current_timestamp()
  61. record.f_cache_key = uuid1().hex
  62. cache.key = record.f_cache_key
  63. record.f_cache = cache
  64. record.f_job_id = job_id
  65. record.f_role = role
  66. record.f_party_id = party_id
  67. record.f_component_name = component_name
  68. record.f_task_id = task_id
  69. record.f_task_version = task_version
  70. record.f_cache_name = cache_name
  71. rows = record.save(force_insert=True)
  72. if rows != 1:
  73. raise Exception("save cache tracking failed")
  74. return record.f_cache_key
  75. @classmethod
  76. @DB.connection_context()
  77. def query(cls, cache_key: str = None, role: str = None, party_id: int = None, component_name: str = None, cache_name: str = None,
  78. **kwargs) -> typing.List[DataCache]:
  79. if cache_key is not None:
  80. records = CacheRecord.query(cache_key=cache_key)
  81. else:
  82. records = CacheRecord.query(role=role, party_id=party_id, component_name=component_name,
  83. cache_name=cache_name, **kwargs)
  84. return [record.f_cache for record in records]
  85. @classmethod
  86. @DB.connection_context()
  87. def query_record(cls, role: str = None, party_id: int = None, component_name: str = None, **kwargs) -> typing.List[CacheRecord]:
  88. records = CacheRecord.query(role=role, party_id=party_id, component_name=component_name, **kwargs)
  89. return [record for record in records]