cache_loader.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from fate_flow.utils.log_utils import getLogger
  17. from fate_flow.components._base import (
  18. BaseParam,
  19. ComponentBase,
  20. ComponentMeta,
  21. ComponentInputProtocol,
  22. )
  23. from fate_flow.operation.job_tracker import Tracker
  24. from fate_flow.entity import MetricMeta, MetricType
  25. LOGGER = getLogger()
  26. cache_loader_cpn_meta = ComponentMeta("CacheLoader")
  27. @cache_loader_cpn_meta.bind_param
  28. class CacheLoaderParam(BaseParam):
  29. def __init__(self, cache_key=None, job_id=None, component_name=None, cache_name=None):
  30. super().__init__()
  31. self.cache_key = cache_key
  32. self.job_id = job_id
  33. self.component_name = component_name
  34. self.cache_name = cache_name
  35. def check(self):
  36. return True
  37. @cache_loader_cpn_meta.bind_runner.on_guest.on_host
  38. class CacheLoader(ComponentBase):
  39. def __init__(self):
  40. super(CacheLoader, self).__init__()
  41. self.parameters = {}
  42. self.cache_key = None
  43. self.job_id = None
  44. self.component_name = None
  45. self.cache_name = None
  46. def _run(self, cpn_input: ComponentInputProtocol):
  47. self.parameters = cpn_input.parameters
  48. LOGGER.info(self.parameters)
  49. for k, v in self.parameters.items():
  50. if hasattr(self, k):
  51. setattr(self, k, v)
  52. tracker = Tracker(job_id=self.job_id,
  53. role=self.tracker.role,
  54. party_id=self.tracker.party_id,
  55. component_name=self.component_name)
  56. LOGGER.info(f"query cache by cache key: {self.cache_key} cache name: {self.cache_name}")
  57. # todo: use tracker client but not tracker
  58. caches = tracker.query_output_cache(cache_key=self.cache_key, cache_name=self.cache_name)
  59. if not caches:
  60. raise Exception("can not found this cache")
  61. elif len(caches) > 1:
  62. raise Exception(f"found {len(caches)} caches, only support one, please check parameters")
  63. else:
  64. cache = caches[0]
  65. self.cache_output = cache
  66. tracker.job_id = self.tracker.job_id
  67. tracker.component_name = self.tracker.component_name
  68. metric_meta = cache.to_dict()
  69. metric_meta.pop("data")
  70. metric_meta["component_name"] = self.component_name
  71. self.tracker.set_metric_meta(metric_namespace="cache_loader", metric_name=cache.name, metric_meta=MetricMeta(name="cache", metric_type=MetricType.CACHE_INFO, extra_metas=metric_meta))