tracker_app.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from flask import request
  17. from fate_arch.common.base_utils import deserialize_b64
  18. from fate_flow.model.sync_model import SyncComponent
  19. from fate_flow.operation.job_tracker import Tracker
  20. from fate_flow.pipelined_model.pipelined_model import PipelinedModel
  21. from fate_flow.pipelined_model.pipelined_component import PipelinedComponent
  22. from fate_flow.settings import ENABLE_MODEL_STORE
  23. from fate_flow.utils.api_utils import get_json_result, validate_request
  24. from fate_flow.utils.model_utils import gen_party_model_id
  25. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/metric_data/save',
  26. methods=['POST'])
  27. def save_metric_data(job_id, component_name, task_version, task_id, role, party_id):
  28. request_data = request.json
  29. tracker = Tracker(job_id=job_id, component_name=component_name, task_id=task_id, task_version=task_version,
  30. role=role, party_id=party_id)
  31. metrics = [deserialize_b64(metric) for metric in request_data['metrics']]
  32. tracker.save_metric_data(metric_namespace=request_data['metric_namespace'], metric_name=request_data['metric_name'],
  33. metrics=metrics, job_level=request_data['job_level'])
  34. return get_json_result()
  35. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/metric_meta/save',
  36. methods=['POST'])
  37. def save_metric_meta(job_id, component_name, task_version, task_id, role, party_id):
  38. request_data = request.json
  39. tracker = Tracker(job_id=job_id, component_name=component_name, task_id=task_id, task_version=task_version,
  40. role=role, party_id=party_id)
  41. metric_meta = deserialize_b64(request_data['metric_meta'])
  42. tracker.save_metric_meta(metric_namespace=request_data['metric_namespace'], metric_name=request_data['metric_name'],
  43. metric_meta=metric_meta, job_level=request_data['job_level'])
  44. return get_json_result()
  45. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/table_meta/create',
  46. methods=['POST'])
  47. def create_table_meta(job_id, component_name, task_version, task_id, role, party_id):
  48. request_data = request.json
  49. tracker = Tracker(job_id=job_id, component_name=component_name, task_id=task_id, task_version=task_version,
  50. role=role, party_id=party_id)
  51. tracker.save_table_meta(request_data)
  52. return get_json_result()
  53. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/table_meta/get',
  54. methods=['POST'])
  55. def get_table_meta(job_id, component_name, task_version, task_id, role, party_id):
  56. request_data = request.json
  57. tracker = Tracker(job_id=job_id, component_name=component_name, task_id=task_id, task_version=task_version,
  58. role=role, party_id=party_id)
  59. table_meta_dict = tracker.get_table_meta(request_data)
  60. return get_json_result(data=table_meta_dict)
  61. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/model/save',
  62. methods=['POST'])
  63. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/component_model/save',
  64. methods=['POST'])
  65. @validate_request('model_id', 'model_version', 'component_model')
  66. def save_component_model(job_id, component_name, task_version, task_id, role, party_id):
  67. party_model_id = gen_party_model_id(request.json['model_id'], role, party_id)
  68. model_version = request.json['model_version']
  69. pipelined_model = PipelinedModel(party_model_id, model_version)
  70. pipelined_model.write_component_model(request.json['component_model'])
  71. if ENABLE_MODEL_STORE:
  72. sync_component = SyncComponent(
  73. party_model_id=party_model_id,
  74. model_version=model_version,
  75. component_name=component_name,
  76. )
  77. # no need to test sync_component.remote_exists()
  78. sync_component.upload()
  79. return get_json_result()
  80. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/model/get',
  81. methods=['POST'])
  82. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/component_model/get',
  83. methods=['POST'])
  84. @validate_request('model_id', 'model_version', 'search_model_alias')
  85. def get_component_model(job_id, component_name, task_version, task_id, role, party_id):
  86. party_model_id = gen_party_model_id(request.json['model_id'], role, party_id)
  87. model_version = request.json['model_version']
  88. if ENABLE_MODEL_STORE:
  89. sync_component = SyncComponent(
  90. party_model_id=party_model_id,
  91. model_version=model_version,
  92. component_name=component_name,
  93. )
  94. if not sync_component.local_exists() and sync_component.remote_exists():
  95. sync_component.download()
  96. pipelined_model = PipelinedModel(party_model_id, model_version)
  97. data = pipelined_model.read_component_model(component_name, request.json['search_model_alias'], False)
  98. return get_json_result(data=data)
  99. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/model/run_parameters/get',
  100. methods=['POST'])
  101. @validate_request('model_id', 'model_version')
  102. def get_component_model_run_parameters(job_id, component_name, task_version, task_id, role, party_id):
  103. pipelined_component = PipelinedComponent(
  104. role=role,
  105. party_id=party_id,
  106. model_id=request.json['model_id'],
  107. model_version=request.json['model_version'],
  108. )
  109. data = pipelined_component.get_run_parameters()
  110. return get_json_result(data=data)
  111. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/output_data_info/save',
  112. methods=['POST'])
  113. def save_output_data_info(job_id, component_name, task_version, task_id, role, party_id):
  114. request_data = request.json
  115. tracker = Tracker(job_id=job_id, component_name=component_name, task_id=task_id, task_version=task_version,
  116. role=role, party_id=party_id)
  117. tracker.insert_output_data_info_into_db(data_name=request_data["data_name"],
  118. table_namespace=request_data["table_namespace"],
  119. table_name=request_data["table_name"])
  120. return get_json_result()
  121. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/output_data_info/read',
  122. methods=['POST'])
  123. def read_output_data_info(job_id, component_name, task_version, task_id, role, party_id):
  124. request_data = request.json
  125. tracker = Tracker(job_id=job_id, component_name=component_name, task_id=task_id, task_version=task_version,
  126. role=role, party_id=party_id)
  127. output_data_infos = tracker.read_output_data_info_from_db(data_name=request_data["data_name"])
  128. response_data = []
  129. for output_data_info in output_data_infos:
  130. response_data.append(output_data_info.to_human_model_dict())
  131. return get_json_result(data=response_data)
  132. @manager.route('/<job_id>/<component_name>/<task_id>/<task_version>/<role>/<party_id>/summary/save',
  133. methods=['POST'])
  134. def save_component_summary(job_id: str, component_name: str, task_version: int, task_id: str, role: str, party_id: int):
  135. request_data = request.json
  136. tracker = Tracker(job_id=job_id, component_name=component_name, task_id=task_id, task_version=task_version,
  137. role=role, party_id=party_id)
  138. summary_data = request_data['summary']
  139. tracker.insert_summary_into_db(summary_data)
  140. return get_json_result()
  141. @manager.route('/<job_id>/<component_name>/<role>/<party_id>/output/table', methods=['POST'])
  142. def component_output_data_table(job_id, component_name, role, party_id):
  143. output_data_infos = Tracker.query_output_data_infos(job_id=job_id, component_name=component_name, role=role, party_id=party_id)
  144. if output_data_infos:
  145. return get_json_result(retcode=0, retmsg='success', data=[{'table_name': output_data_info.f_table_name,
  146. 'table_namespace': output_data_info.f_table_namespace,
  147. "data_name": output_data_info.f_data_name
  148. } for output_data_info in output_data_infos])
  149. else:
  150. return get_json_result(retcode=100, retmsg='No found table, please check if the parameters are correct')