123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from fate_arch.common.base_utils import current_timestamp, serialize_b64, deserialize_b64
- from fate_flow.utils.log_utils import schedule_logger
- from fate_flow.db import db_utils
- from fate_flow.db.db_models import (DB, TrackingMetric)
- from fate_flow.entity import Metric
- from fate_flow.utils import job_utils
- class MetricManager:
- def __init__(self, job_id: str, role: str, party_id: int,
- component_name: str,
- task_id: str = None,
- task_version: int = None):
- self.job_id = job_id
- self.role = role
- self.party_id = party_id
- self.component_name = component_name
- self.task_id = task_id
- self.task_version = task_version
- @DB.connection_context()
- def read_metric_data(self, metric_namespace: str, metric_name: str, job_level=False):
- metrics = []
- for k, v in self.read_metrics_from_db(metric_namespace, metric_name, 1, job_level):
- metrics.append(Metric(key=k, value=v))
- return metrics
- @DB.connection_context()
- def insert_metrics_into_db(self, metric_namespace: str, metric_name: str, data_type: int, kv, job_level=False):
- try:
- model_class = self.get_model_class()
- tracking_metric = model_class()
- tracking_metric.f_job_id = self.job_id
- tracking_metric.f_component_name = (self.component_name if not job_level
- else job_utils.PIPELINE_COMPONENT_NAME)
- tracking_metric.f_task_id = self.task_id
- tracking_metric.f_task_version = self.task_version
- tracking_metric.f_role = self.role
- tracking_metric.f_party_id = self.party_id
- tracking_metric.f_metric_namespace = metric_namespace
- tracking_metric.f_metric_name = metric_name
- tracking_metric.f_type = data_type
- default_db_source = tracking_metric.to_dict()
- tracking_metric_data_source = []
- for k, v in kv:
- db_source = default_db_source.copy()
- db_source['f_key'] = serialize_b64(k)
- db_source['f_value'] = serialize_b64(v)
- db_source['f_create_time'] = current_timestamp()
- tracking_metric_data_source.append(db_source)
- db_utils.bulk_insert_into_db(model_class, tracking_metric_data_source)
- except Exception as e:
- schedule_logger(self.job_id).exception(
- "An exception where inserted metric {} of metric namespace: {} to database:\n{}".format(
- metric_name,
- metric_namespace,
- e
- ))
- @DB.connection_context()
- def read_metrics_from_db(self, metric_namespace: str, metric_name: str, data_type, job_level=False):
- metrics = []
- try:
- tracking_metric_model = self.get_model_class()
- tracking_metrics = tracking_metric_model.select(tracking_metric_model.f_key,
- tracking_metric_model.f_value).where(
- tracking_metric_model.f_job_id == self.job_id,
- tracking_metric_model.f_component_name == (self.component_name if not job_level
- else job_utils.PIPELINE_COMPONENT_NAME),
- tracking_metric_model.f_role == self.role,
- tracking_metric_model.f_party_id == self.party_id,
- tracking_metric_model.f_metric_namespace == metric_namespace,
- tracking_metric_model.f_metric_name == metric_name,
- tracking_metric_model.f_type == data_type
- )
- for tracking_metric in tracking_metrics:
- yield deserialize_b64(tracking_metric.f_key), deserialize_b64(tracking_metric.f_value)
- except Exception as e:
- schedule_logger(self.job_id).exception(e)
- raise e
- return metrics
- @DB.connection_context()
- def clean_metrics(self):
- tracking_metric_model = self.get_model_class()
- operate = tracking_metric_model.delete().where(
- tracking_metric_model.f_task_id == self.task_id,
- tracking_metric_model.f_task_version == self.task_version,
- tracking_metric_model.f_role == self.role,
- tracking_metric_model.f_party_id == self.party_id
- )
- return operate.execute() > 0
- @DB.connection_context()
- def get_metric_list(self, job_level: bool = False):
- metrics = {}
- tracking_metric_model = self.get_model_class()
- if tracking_metric_model.table_exists():
- tracking_metrics = tracking_metric_model.select(
- tracking_metric_model.f_metric_namespace,
- tracking_metric_model.f_metric_name
- ).where(
- tracking_metric_model.f_job_id == self.job_id,
- tracking_metric_model.f_component_name == (self.component_name if not job_level else 'dag'),
- tracking_metric_model.f_role == self.role,
- tracking_metric_model.f_party_id == self.party_id
- ).distinct()
- for tracking_metric in tracking_metrics:
- metrics[tracking_metric.f_metric_namespace] = metrics.get(tracking_metric.f_metric_namespace, [])
- metrics[tracking_metric.f_metric_namespace].append(tracking_metric.f_metric_name)
- return metrics
- @DB.connection_context()
- def read_component_metrics(self):
- try:
- tracking_metric_model = self.get_model_class()
- tracking_metrics = tracking_metric_model.select().where(
- tracking_metric_model.f_job_id == self.job_id,
- tracking_metric_model.f_component_name == self.component_name,
- tracking_metric_model.f_role == self.role,
- tracking_metric_model.f_party_id == self.party_id,
- tracking_metric_model.f_task_version == self.task_version
- )
- return [tracking_metric for tracking_metric in tracking_metrics]
- except Exception as e:
- schedule_logger(self.job_id).exception(e)
- raise e
- @DB.connection_context()
- def reload_metric(self, source_metric_manager):
- component_metrics = source_metric_manager.read_component_metrics()
- for component_metric in component_metrics:
- model_class = self.get_model_class()
- tracking_metric = model_class()
- tracking_metric.f_job_id = self.job_id
- tracking_metric.f_component_name = self.component_name
- tracking_metric.f_task_id = self.task_id
- tracking_metric.f_task_version = self.task_version
- tracking_metric.f_role = self.role
- tracking_metric.f_party_id = self.party_id
- tracking_metric.f_metric_namespace = component_metric.f_metric_namespace
- tracking_metric.f_metric_name = component_metric.f_metric_name
- tracking_metric.f_type = component_metric.f_type
- tracking_metric.f_key = component_metric.f_key
- tracking_metric.f_value = component_metric.f_value
- tracking_metric.save()
- def get_model_class(self):
- return db_utils.get_dynamic_db_model(TrackingMetric, self.job_id)
|