tracker_client.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import typing
  18. from typing import List
  19. from fate_arch import storage
  20. from fate_arch.abc import AddressABC
  21. from fate_flow.utils.log_utils import getLogger
  22. from fate_flow.entity import RunParameters
  23. from fate_arch.common.base_utils import serialize_b64, deserialize_b64
  24. from fate_flow.entity import RetCode
  25. from fate_flow.entity import Metric, MetricMeta
  26. from fate_flow.operation.job_tracker import Tracker
  27. from fate_flow.utils import api_utils
# Module-level logger from fate_flow.utils.log_utils.getLogger(); used by the
# TrackerClient methods below to record each tracker request before it is sent.
LOGGER = getLogger()
  29. class TrackerClient(object):
  30. def __init__(self, job_id: str, role: str, party_id: int,
  31. model_id: str = None,
  32. model_version: str = None,
  33. component_name: str = None,
  34. component_module_name: str = None,
  35. task_id: str = None,
  36. task_version: int = None,
  37. job_parameters: RunParameters = None
  38. ):
  39. self.job_id = job_id
  40. self.role = role
  41. self.party_id = party_id
  42. self.model_id = model_id
  43. self.model_version = model_version
  44. self.component_name = component_name if component_name else 'pipeline'
  45. self.module_name = component_module_name if component_module_name else 'Pipeline'
  46. self.task_id = task_id
  47. self.task_version = task_version
  48. self.job_parameters = job_parameters
  49. self.job_tracker = Tracker(job_id=job_id, role=role, party_id=party_id, component_name=component_name,
  50. task_id=task_id,
  51. task_version=task_version,
  52. model_id=model_id,
  53. model_version=model_version,
  54. job_parameters=job_parameters)
  55. def log_job_metric_data(self, metric_namespace: str, metric_name: str, metrics: List[typing.Union[Metric, dict]]):
  56. self.log_metric_data_common(metric_namespace=metric_namespace, metric_name=metric_name, metrics=metrics,
  57. job_level=True)
  58. def log_metric_data(self, metric_namespace: str, metric_name: str, metrics: List[typing.Union[Metric, dict]]):
  59. self.log_metric_data_common(metric_namespace=metric_namespace, metric_name=metric_name, metrics=metrics,
  60. job_level=False)
  61. def log_metric_data_common(self, metric_namespace: str, metric_name: str, metrics: List[typing.Union[Metric, dict]], job_level=False):
  62. LOGGER.info("Request save job {} task {} {} on {} {} metric {} {} data".format(self.job_id,
  63. self.task_id,
  64. self.task_version,
  65. self.role,
  66. self.party_id,
  67. metric_namespace,
  68. metric_name))
  69. request_body = {}
  70. request_body['metric_namespace'] = metric_namespace
  71. request_body['metric_name'] = metric_name
  72. request_body['metrics'] = [serialize_b64(metric if isinstance(metric, Metric) else Metric.from_dict(metric), to_str=True) for metric in metrics]
  73. request_body['job_level'] = job_level
  74. response = api_utils.local_api(job_id=self.job_id,
  75. method='POST',
  76. endpoint='/tracker/{}/{}/{}/{}/{}/{}/metric_data/save'.format(
  77. self.job_id,
  78. self.component_name,
  79. self.task_id,
  80. self.task_version,
  81. self.role,
  82. self.party_id),
  83. json_body=request_body)
  84. if response['retcode'] != RetCode.SUCCESS:
  85. raise Exception(f"log metric(namespace: {metric_namespace}, name: {metric_name}) data error, response code: {response['retcode']}, msg: {response['retmsg']}")
  86. def set_job_metric_meta(self, metric_namespace: str, metric_name: str, metric_meta: typing.Union[MetricMeta, dict]):
  87. self.set_metric_meta_common(metric_namespace=metric_namespace, metric_name=metric_name, metric_meta=metric_meta,
  88. job_level=True)
  89. def set_metric_meta(self, metric_namespace: str, metric_name: str, metric_meta: typing.Union[MetricMeta, dict]):
  90. self.set_metric_meta_common(metric_namespace=metric_namespace, metric_name=metric_name, metric_meta=metric_meta,
  91. job_level=False)
  92. def set_metric_meta_common(self, metric_namespace: str, metric_name: str, metric_meta: typing.Union[MetricMeta, dict], job_level=False):
  93. LOGGER.info("Request save job {} task {} {} on {} {} metric {} {} meta".format(self.job_id,
  94. self.task_id,
  95. self.task_version,
  96. self.role,
  97. self.party_id,
  98. metric_namespace,
  99. metric_name))
  100. request_body = dict()
  101. request_body['metric_namespace'] = metric_namespace
  102. request_body['metric_name'] = metric_name
  103. request_body['metric_meta'] = serialize_b64(metric_meta if isinstance(metric_meta, MetricMeta) else MetricMeta.from_dict(metric_meta), to_str=True)
  104. request_body['job_level'] = job_level
  105. response = api_utils.local_api(job_id=self.job_id,
  106. method='POST',
  107. endpoint='/tracker/{}/{}/{}/{}/{}/{}/metric_meta/save'.format(
  108. self.job_id,
  109. self.component_name,
  110. self.task_id,
  111. self.task_version,
  112. self.role,
  113. self.party_id),
  114. json_body=request_body)
  115. if response['retcode'] != RetCode.SUCCESS:
  116. raise Exception(f"log metric(namespace: {metric_namespace}, name: {metric_name}) meta error, response code: {response['retcode']}, msg: {response['retmsg']}")
  117. def create_table_meta(self, table_meta):
  118. request_body = dict()
  119. for k, v in table_meta.to_dict().items():
  120. if k == "part_of_data":
  121. request_body[k] = serialize_b64(v, to_str=True)
  122. elif k == "schema":
  123. request_body[k] = serialize_b64(v, to_str=True)
  124. elif issubclass(type(v), AddressABC):
  125. request_body[k] = v.__dict__
  126. else:
  127. request_body[k] = v
  128. response = api_utils.local_api(job_id=self.job_id,
  129. method='POST',
  130. endpoint='/tracker/{}/{}/{}/{}/{}/{}/table_meta/create'.format(
  131. self.job_id,
  132. self.component_name,
  133. self.task_id,
  134. self.task_version,
  135. self.role,
  136. self.party_id),
  137. json_body=request_body)
  138. if response['retcode'] != RetCode.SUCCESS:
  139. raise Exception(f"create table meta failed:{response['retmsg']}")
  140. def get_table_meta(self, table_name, table_namespace):
  141. request_body = {"table_name": table_name, "namespace": table_namespace}
  142. response = api_utils.local_api(job_id=self.job_id,
  143. method='POST',
  144. endpoint='/tracker/{}/{}/{}/{}/{}/{}/table_meta/get'.format(
  145. self.job_id,
  146. self.component_name,
  147. self.task_id,
  148. self.task_version,
  149. self.role,
  150. self.party_id),
  151. json_body=request_body)
  152. if response['retcode'] != RetCode.SUCCESS:
  153. raise Exception(f"create table meta failed:{response['retmsg']}")
  154. else:
  155. data_table_meta = storage.StorageTableMeta(name=table_name,
  156. namespace=table_namespace, new=True)
  157. data_table_meta.set_metas(**response["data"])
  158. data_table_meta.address = storage.StorageTableMeta.create_address(storage_engine=response["data"].get("engine"),
  159. address_dict=response["data"].get("address"))
  160. data_table_meta.part_of_data = deserialize_b64(data_table_meta.part_of_data)
  161. data_table_meta.schema = deserialize_b64(data_table_meta.schema)
  162. return data_table_meta
  163. def save_component_output_model(self, model_buffers: dict, model_alias: str, user_specified_run_parameters: dict = None):
  164. component_model = self.job_tracker.pipelined_model.create_component_model(component_name=self.component_name,
  165. component_module_name=self.module_name,
  166. model_alias=model_alias,
  167. model_buffers=model_buffers,
  168. user_specified_run_parameters=user_specified_run_parameters)
  169. json_body = {"model_id": self.model_id, "model_version": self.model_version, "component_model": component_model}
  170. response = api_utils.local_api(job_id=self.job_id,
  171. method='POST',
  172. endpoint='/tracker/{}/{}/{}/{}/{}/{}/model/save'.format(
  173. self.job_id,
  174. self.component_name,
  175. self.task_id,
  176. self.task_version,
  177. self.role,
  178. self.party_id),
  179. json_body=json_body)
  180. if response['retcode'] != RetCode.SUCCESS:
  181. raise Exception(f"save component output model failed:{response['retmsg']}")
  182. def read_component_output_model(self, search_model_alias):
  183. json_body = {"search_model_alias": search_model_alias, "model_id": self.model_id, "model_version": self.model_version}
  184. response = api_utils.local_api(job_id=self.job_id,
  185. method='POST',
  186. endpoint='/tracker/{}/{}/{}/{}/{}/{}/model/get'.format(
  187. self.job_id,
  188. self.component_name,
  189. self.task_id,
  190. self.task_version,
  191. self.role,
  192. self.party_id),
  193. json_body=json_body)
  194. if response['retcode'] != RetCode.SUCCESS:
  195. raise Exception(f"get output model failed:{response['retmsg']}")
  196. else:
  197. model_buffers = {}
  198. for model_name, v in response['data'].items():
  199. model_buffers[model_name] = (v[0], base64.b64decode(v[1].encode()))
  200. return model_buffers
  201. def get_model_run_parameters(self):
  202. json_body = {"model_id": self.model_id, "model_version": self.model_version}
  203. response = api_utils.local_api(job_id=self.job_id,
  204. method='POST',
  205. endpoint='/tracker/{}/{}/{}/{}/{}/{}/model/run_parameters/get'.format(
  206. self.job_id,
  207. self.component_name,
  208. self.task_id,
  209. self.task_version,
  210. self.role,
  211. self.party_id),
  212. json_body=json_body)
  213. if response['retcode'] != RetCode.SUCCESS:
  214. raise Exception(f"create table meta failed:{response['retmsg']}")
  215. else:
  216. return response["data"]
  217. def log_output_data_info(self, data_name: str, table_namespace: str, table_name: str):
  218. LOGGER.info("Request save job {} task {} {} on {} {} data {} info".format(self.job_id,
  219. self.task_id,
  220. self.task_version,
  221. self.role,
  222. self.party_id,
  223. data_name))
  224. request_body = dict()
  225. request_body["data_name"] = data_name
  226. request_body["table_namespace"] = table_namespace
  227. request_body["table_name"] = table_name
  228. response = api_utils.local_api(job_id=self.job_id,
  229. method='POST',
  230. endpoint='/tracker/{}/{}/{}/{}/{}/{}/output_data_info/save'.format(
  231. self.job_id,
  232. self.component_name,
  233. self.task_id,
  234. self.task_version,
  235. self.role,
  236. self.party_id),
  237. json_body=request_body)
  238. if response['retcode'] != RetCode.SUCCESS:
  239. raise Exception(f"log output data info error, response code: {response['retcode']}, msg: {response['retmsg']}")
  240. def get_output_data_info(self, data_name=None):
  241. LOGGER.info("Request read job {} task {} {} on {} {} data {} info".format(self.job_id,
  242. self.task_id,
  243. self.task_version,
  244. self.role,
  245. self.party_id,
  246. data_name))
  247. request_body = dict()
  248. request_body["data_name"] = data_name
  249. response = api_utils.local_api(job_id=self.job_id,
  250. method='POST',
  251. endpoint='/tracker/{}/{}/{}/{}/{}/{}/output_data_info/read'.format(
  252. self.job_id,
  253. self.component_name,
  254. self.task_id,
  255. self.task_version,
  256. self.role,
  257. self.party_id),
  258. json_body=request_body)
  259. if response["retcode"] == RetCode.SUCCESS and "data" in response:
  260. return response["data"]
  261. else:
  262. return None
  263. def log_component_summary(self, summary_data: dict):
  264. LOGGER.info("Request save job {} task {} {} on {} {} component summary".format(self.job_id,
  265. self.task_id,
  266. self.task_version,
  267. self.role,
  268. self.party_id))
  269. request_body = dict()
  270. request_body["summary"] = summary_data
  271. response = api_utils.local_api(job_id=self.job_id,
  272. method='POST',
  273. endpoint='/tracker/{}/{}/{}/{}/{}/{}/summary/save'.format(
  274. self.job_id,
  275. self.component_name,
  276. self.task_id,
  277. self.task_version,
  278. self.role,
  279. self.party_id),
  280. json_body=request_body)
  281. if response['retcode'] != RetCode.SUCCESS:
  282. raise Exception(f"log component summary error, response code: {response['retcode']}, msg: {response['retmsg']}")