metric_manager.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from fate_arch.common.base_utils import current_timestamp, serialize_b64, deserialize_b64
  17. from fate_flow.utils.log_utils import schedule_logger
  18. from fate_flow.db import db_utils
  19. from fate_flow.db.db_models import (DB, TrackingMetric)
  20. from fate_flow.entity import Metric
  21. from fate_flow.utils import job_utils
  22. class MetricManager:
  23. def __init__(self, job_id: str, role: str, party_id: int,
  24. component_name: str,
  25. task_id: str = None,
  26. task_version: int = None):
  27. self.job_id = job_id
  28. self.role = role
  29. self.party_id = party_id
  30. self.component_name = component_name
  31. self.task_id = task_id
  32. self.task_version = task_version
  33. @DB.connection_context()
  34. def read_metric_data(self, metric_namespace: str, metric_name: str, job_level=False):
  35. metrics = []
  36. for k, v in self.read_metrics_from_db(metric_namespace, metric_name, 1, job_level):
  37. metrics.append(Metric(key=k, value=v))
  38. return metrics
  39. @DB.connection_context()
  40. def insert_metrics_into_db(self, metric_namespace: str, metric_name: str, data_type: int, kv, job_level=False):
  41. try:
  42. model_class = self.get_model_class()
  43. tracking_metric = model_class()
  44. tracking_metric.f_job_id = self.job_id
  45. tracking_metric.f_component_name = (self.component_name if not job_level
  46. else job_utils.PIPELINE_COMPONENT_NAME)
  47. tracking_metric.f_task_id = self.task_id
  48. tracking_metric.f_task_version = self.task_version
  49. tracking_metric.f_role = self.role
  50. tracking_metric.f_party_id = self.party_id
  51. tracking_metric.f_metric_namespace = metric_namespace
  52. tracking_metric.f_metric_name = metric_name
  53. tracking_metric.f_type = data_type
  54. default_db_source = tracking_metric.to_dict()
  55. tracking_metric_data_source = []
  56. for k, v in kv:
  57. db_source = default_db_source.copy()
  58. db_source['f_key'] = serialize_b64(k)
  59. db_source['f_value'] = serialize_b64(v)
  60. db_source['f_create_time'] = current_timestamp()
  61. tracking_metric_data_source.append(db_source)
  62. db_utils.bulk_insert_into_db(model_class, tracking_metric_data_source)
  63. except Exception as e:
  64. schedule_logger(self.job_id).exception(
  65. "An exception where inserted metric {} of metric namespace: {} to database:\n{}".format(
  66. metric_name,
  67. metric_namespace,
  68. e
  69. ))
  70. @DB.connection_context()
  71. def read_metrics_from_db(self, metric_namespace: str, metric_name: str, data_type, job_level=False):
  72. metrics = []
  73. try:
  74. tracking_metric_model = self.get_model_class()
  75. tracking_metrics = tracking_metric_model.select(tracking_metric_model.f_key,
  76. tracking_metric_model.f_value).where(
  77. tracking_metric_model.f_job_id == self.job_id,
  78. tracking_metric_model.f_component_name == (self.component_name if not job_level
  79. else job_utils.PIPELINE_COMPONENT_NAME),
  80. tracking_metric_model.f_role == self.role,
  81. tracking_metric_model.f_party_id == self.party_id,
  82. tracking_metric_model.f_metric_namespace == metric_namespace,
  83. tracking_metric_model.f_metric_name == metric_name,
  84. tracking_metric_model.f_type == data_type
  85. )
  86. for tracking_metric in tracking_metrics:
  87. yield deserialize_b64(tracking_metric.f_key), deserialize_b64(tracking_metric.f_value)
  88. except Exception as e:
  89. schedule_logger(self.job_id).exception(e)
  90. raise e
  91. return metrics
  92. @DB.connection_context()
  93. def clean_metrics(self):
  94. tracking_metric_model = self.get_model_class()
  95. operate = tracking_metric_model.delete().where(
  96. tracking_metric_model.f_task_id == self.task_id,
  97. tracking_metric_model.f_task_version == self.task_version,
  98. tracking_metric_model.f_role == self.role,
  99. tracking_metric_model.f_party_id == self.party_id
  100. )
  101. return operate.execute() > 0
  102. @DB.connection_context()
  103. def get_metric_list(self, job_level: bool = False):
  104. metrics = {}
  105. tracking_metric_model = self.get_model_class()
  106. if tracking_metric_model.table_exists():
  107. tracking_metrics = tracking_metric_model.select(
  108. tracking_metric_model.f_metric_namespace,
  109. tracking_metric_model.f_metric_name
  110. ).where(
  111. tracking_metric_model.f_job_id == self.job_id,
  112. tracking_metric_model.f_component_name == (self.component_name if not job_level else 'dag'),
  113. tracking_metric_model.f_role == self.role,
  114. tracking_metric_model.f_party_id == self.party_id
  115. ).distinct()
  116. for tracking_metric in tracking_metrics:
  117. metrics[tracking_metric.f_metric_namespace] = metrics.get(tracking_metric.f_metric_namespace, [])
  118. metrics[tracking_metric.f_metric_namespace].append(tracking_metric.f_metric_name)
  119. return metrics
  120. @DB.connection_context()
  121. def read_component_metrics(self):
  122. try:
  123. tracking_metric_model = self.get_model_class()
  124. tracking_metrics = tracking_metric_model.select().where(
  125. tracking_metric_model.f_job_id == self.job_id,
  126. tracking_metric_model.f_component_name == self.component_name,
  127. tracking_metric_model.f_role == self.role,
  128. tracking_metric_model.f_party_id == self.party_id,
  129. tracking_metric_model.f_task_version == self.task_version
  130. )
  131. return [tracking_metric for tracking_metric in tracking_metrics]
  132. except Exception as e:
  133. schedule_logger(self.job_id).exception(e)
  134. raise e
  135. @DB.connection_context()
  136. def reload_metric(self, source_metric_manager):
  137. component_metrics = source_metric_manager.read_component_metrics()
  138. for component_metric in component_metrics:
  139. model_class = self.get_model_class()
  140. tracking_metric = model_class()
  141. tracking_metric.f_job_id = self.job_id
  142. tracking_metric.f_component_name = self.component_name
  143. tracking_metric.f_task_id = self.task_id
  144. tracking_metric.f_task_version = self.task_version
  145. tracking_metric.f_role = self.role
  146. tracking_metric.f_party_id = self.party_id
  147. tracking_metric.f_metric_namespace = component_metric.f_metric_namespace
  148. tracking_metric.f_metric_name = component_metric.f_metric_name
  149. tracking_metric.f_type = component_metric.f_type
  150. tracking_metric.f_key = component_metric.f_key
  151. tracking_metric.f_value = component_metric.f_value
  152. tracking_metric.save()
  153. def get_model_class(self):
  154. return db_utils.get_dynamic_db_model(TrackingMetric, self.job_id)