table_app.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request

from fate_arch import storage
from fate_arch.metastore.db_utils import StorageConnector
from fate_arch.session import Session
from fate_arch.storage import StorageTableMeta, StorageTableOrigin
from fate_flow.entity import RunParameters
from fate_flow.manager.data_manager import DataTableTracker, TableStorage, SchemaMetaParam, AnonymousGenerator
from fate_flow.operation.job_saver import JobSaver
from fate_flow.operation.job_tracker import Tracker
from fate_flow.utils import job_utils, schedule_utils
from fate_flow.utils.api_utils import get_json_result, error_response, validate_request
from fate_flow.utils.data_utils import get_extend_id_name, address_filter
from fate_flow.worker.task_executor import TaskExecutor

# NOTE: `manager` (the Flask blueprint these routes attach to) is injected by
# the fate_flow app loader before this module is executed, which is why it is
# not imported here.


@manager.route('/connector/create', methods=['POST'])
def create_storage_connector():
    request_data = request.json
    address = StorageTableMeta.create_address(request_data.get("engine"), request_data.get("connector_info"))
    connector = StorageConnector(connector_name=request_data.get("connector_name"), engine=request_data.get("engine"),
                                 connector_info=address.connector)
    connector.create_or_update()
    return get_json_result(retcode=0, retmsg='success')
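
# A minimal client sketch (illustrative only): the connector fields below are
# assumptions for a MySQL engine, and the host/port and /v1/table prefix
# depend on how the deployment mounts this blueprint.
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:9380/v1/table/connector/create",
#       json={"connector_name": "my_mysql_connector",  # hypothetical name
#             "engine": "MYSQL",
#             "connector_info": {"user": "fate", "passwd": "fate",
#                                "host": "127.0.0.1", "port": 3306}})
#   print(resp.json())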


@manager.route('/connector/query', methods=['POST'])
def query_storage_connector():
    request_data = request.json
    connector = StorageConnector(connector_name=request_data.get("connector_name"))
    return get_json_result(retcode=0, retmsg='success', data=connector.get_info())


@manager.route('/add', methods=['post'])
@manager.route('/bind', methods=['post'])
@validate_request("engine", "address", "namespace", "name")
def table_bind():
    request_data = request.json
    address_dict = request_data.get('address')
    engine = request_data.get('engine')
    name = request_data.get('name')
    namespace = request_data.get('namespace')
    extra_schema = request_data.get("schema", {})
    address = storage.StorageTableMeta.create_address(storage_engine=engine, address_dict=address_dict)
    in_serialized = request_data.get("in_serialized",
                                     1 if engine in {storage.StorageEngine.STANDALONE, storage.StorageEngine.EGGROLL,
                                                     storage.StorageEngine.MYSQL, storage.StorageEngine.PATH,
                                                     storage.StorageEngine.API} else 0)
    destroy = (int(request_data.get("drop", 0)) == 1)
    data_table_meta = storage.StorageTableMeta(name=name, namespace=namespace)
    if data_table_meta:
        if destroy:
            data_table_meta.destroy_metas()
        else:
            return get_json_result(retcode=100,
                                   retmsg='The data table already exists. '
                                          'If you still want to continue uploading, please add the parameter --drop')
    id_column = request_data.get("id_column") or request_data.get("id_name")
    feature_column = request_data.get("feature_column") or request_data.get("feature_name")
    schema = get_bind_table_schema(id_column, feature_column)
    schema.update(extra_schema)
    if request_data.get("with_meta", False):
        meta = SchemaMetaParam(**request_data.get("meta", {}))
        if request_data.get("extend_sid", False):
            meta.with_match_id = True
        schema.update({"meta": meta.to_dict()})
        extra_schema["meta"] = meta.to_dict()
    sess = Session()
    storage_session = sess.storage(storage_engine=engine, options=request_data.get("options"))
    table = storage_session.create_table(address=address, name=name, namespace=namespace,
                                         partitions=request_data.get('partitions', None),
                                         have_head=request_data.get("head"), schema=schema,
                                         extend_sid=request_data.get("extend_sid", False),
                                         id_delimiter=request_data.get("id_delimiter"),
                                         in_serialized=in_serialized,
                                         origin=request_data.get("origin", StorageTableOrigin.TABLE_BIND))
    response = get_json_result(data={"table_name": name, "namespace": namespace})
    if not table.check_address():
        response = get_json_result(retcode=100, retmsg=f'engine {engine} address {address_dict} check failed')
    else:
        if request_data.get("extend_sid"):
            schema = update_bind_table_schema(id_column, feature_column, request_data.get("extend_sid"),
                                              request_data.get("id_delimiter"))
            schema.update(extra_schema)
            table.meta.update_metas(schema=schema)
        DataTableTracker.create_table_tracker(
            table_name=name,
            table_namespace=namespace,
            entity_info={"have_parent": False},
        )
    sess.destroy_all_sessions()
    return response
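
# A minimal bind sketch (illustrative only): the MySQL address fields and
# table/namespace names are assumptions, and the /v1/table prefix depends on
# how the deployment mounts this blueprint.
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:9380/v1/table/bind",
#       json={"engine": "MYSQL",
#             "name": "breast_hetero_guest",   # hypothetical table
#             "namespace": "experiment",
#             "address": {"user": "fate", "passwd": "fate",
#                         "host": "127.0.0.1", "port": 3306,
#                         "db": "fate", "name": "breast_hetero_guest"},
#             "id_column": "id",
#             "feature_column": "x0,x1,x2",
#             "drop": 1})
#   print(resp.json())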


@manager.route('/schema/update', methods=['post'])
@validate_request("schema", "namespace", "name")
def schema_update():
    request_data = request.json
    data_table_meta = storage.StorageTableMeta(name=request_data.get("name"), namespace=request_data.get("namespace"))
    schema = data_table_meta.get_schema()
    if request_data.get("schema", {}).get("meta"):
        if schema.get("meta"):
            schema = AnonymousGenerator.recover_schema(schema)
            schema["meta"].update(request_data.get("schema").get("meta"))
        else:
            return get_json_result(retcode=101, retmsg="meta not found")
        request_data["schema"].pop("meta", {})
    schema.update(request_data.get("schema", {}))
    data_table_meta.update_metas(schema=schema)
    return get_json_result(data=schema)
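
# A minimal schema-update sketch (illustrative only; the meta key used here is
# a hypothetical SchemaMetaParam field, and the prefix/port are deployment
# assumptions):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:9380/v1/table/schema/update",
#       json={"name": "breast_hetero_guest", "namespace": "experiment",
#             "schema": {"meta": {"label_name": "y"}}})
#   print(resp.json())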


@manager.route('/schema/anonymous/migrate', methods=['post'])
@validate_request("namespace", "name", "role", "party_id", "migrate_mapping")
def meta_update():
    request_data = request.json
    data_table_meta = storage.StorageTableMeta(name=request_data.get("name"), namespace=request_data.get("namespace"))
    schema = data_table_meta.get_schema()
    update_schema = AnonymousGenerator.migrate_schema_anonymous(
        anonymous_schema=schema,
        role=request_data.get("role"),
        party_id=request_data.get("party_id"),
        migrate_mapping=request_data.get("migrate_mapping"))
    if update_schema:
        schema.update(update_schema)
        data_table_meta.update_metas(schema=schema)
        return get_json_result(data=schema)
    else:
        return get_json_result(retcode=101, retmsg="update failed")


@manager.route('/download', methods=['get'])
def table_download():
    request_data = request.json
    # Deferred import: component output dependencies are only needed here.
    from fate_flow.component_env_utils.env_utils import import_component_output_depend
    import_component_output_depend()
    data_table_meta = storage.StorageTableMeta(name=request_data.get("name"), namespace=request_data.get("namespace"))
    if not data_table_meta:
        return error_response(response_code=210,
                              retmsg=f'table not found: {request_data.get("namespace")}, {request_data.get("name")}')
    tar_file_name = 'table_{}_{}.tar.gz'.format(request_data.get("namespace"), request_data.get("name"))
    return TableStorage.send_table(
        output_tables_meta={"table": data_table_meta},
        tar_file_name=tar_file_name,
        need_head=request_data.get("head", True)
    )
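
# A minimal download sketch (illustrative only). Note the route accepts GET
# but reads a JSON body, so the client must send one; prefix/port are
# deployment assumptions.
#
#   import requests
#   resp = requests.get(
#       "http://127.0.0.1:9380/v1/table/download",
#       json={"name": "breast_hetero_guest", "namespace": "experiment"})
#   with open("table_experiment_breast_hetero_guest.tar.gz", "wb") as f:
#       f.write(resp.content)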


@manager.route('/delete', methods=['post'])
def table_delete():
    request_data = request.json
    table_name = request_data.get('table_name')
    namespace = request_data.get('namespace')
    data = None
    sess = Session()
    table = sess.get_table(name=table_name, namespace=namespace, ignore_disable=True)
    if table:
        table.destroy()
        data = {'table_name': table_name, 'namespace': namespace}
    sess.destroy_all_sessions()
    if data:
        return get_json_result(data=data)
    return get_json_result(retcode=101, retmsg='table not found')


@manager.route('/disable', methods=['post'])
@manager.route('/enable', methods=['post'])
def table_disable():
    request_data = request.json
    adapter_request_data(request_data)
    disable = request.url.endswith("disable")
    tables_meta = storage.StorageTableMeta.query_table_meta(filter_fields=dict(**request_data))
    data = []
    if tables_meta:
        for table_meta in tables_meta:
            storage.StorageTableMeta(name=table_meta.f_name,
                                     namespace=table_meta.f_namespace
                                     ).update_metas(disable=disable)
            data.append({'table_name': table_meta.f_name, 'namespace': table_meta.f_namespace})
        return get_json_result(data=data)
    return get_json_result(retcode=101, retmsg='table not found')


@manager.route('/disable/delete', methods=['post'])
def table_delete_disable():
    request_data = request.json
    adapter_request_data(request_data)
    tables_meta = storage.StorageTableMeta.query_table_meta(filter_fields={"disable": True})
    data = []
    sess = Session()
    for table_meta in tables_meta:
        table = sess.get_table(name=table_meta.f_name, namespace=table_meta.f_namespace, ignore_disable=True)
        if table:
            table.destroy()
            data.append({'table_name': table_meta.f_name, 'namespace': table_meta.f_namespace})
    sess.destroy_all_sessions()
    if data:
        return get_json_result(data=data)
    return get_json_result(retcode=101, retmsg='table not found')


@manager.route('/list', methods=['post'])
@validate_request('job_id', 'role', 'party_id')
def get_job_table_list():
    jobs = JobSaver.query_job(**request.json)
    if jobs:
        job = jobs[0]
        tables = get_job_all_table(job)
        return get_json_result(data=tables)
    else:
        return get_json_result(retcode=101, retmsg='job not found')


@manager.route('/<table_func>', methods=['post'])
def table_api(table_func):
    config = request.json
    if table_func == 'table_info':
        table_key_count = 0
        table_partition = None
        table_schema = None
        extend_sid = False
        table_name, namespace = config.get("name") or config.get("table_name"), config.get("namespace")
        table_meta = storage.StorageTableMeta(name=table_name, namespace=namespace)
        address = None
        enable = True
        origin = None
        if table_meta:
            table_key_count = table_meta.get_count()
            table_partition = table_meta.get_partitions()
            table_schema = table_meta.get_schema()
            extend_sid = table_meta.get_extend_sid()
            address = address_filter(table_meta.get_address())
            enable = not table_meta.get_disable()
            origin = table_meta.get_origin()
            exist = 1
        else:
            exist = 0
        return get_json_result(data={"table_name": table_name,
                                     "namespace": namespace,
                                     "exist": exist,
                                     "count": table_key_count,
                                     "partition": table_partition,
                                     "schema": table_schema,
                                     "enable": enable,
                                     "origin": origin,
                                     "extend_sid": extend_sid,
                                     "address": address})
    else:
        return get_json_result()
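
# A minimal table_info sketch (illustrative only; table/namespace names and
# the /v1/table prefix are assumptions):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:9380/v1/table/table_info",
#       json={"name": "breast_hetero_guest", "namespace": "experiment"})
#   print(resp.json()["data"]["count"])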


@manager.route('/tracking/source', methods=['post'])
@validate_request("table_name", "namespace")
def table_tracking():
    request_info = request.json
    data = DataTableTracker.get_parent_table(request_info.get("table_name"), request_info.get("namespace"))
    return get_json_result(data=data)


@manager.route('/tracking/job', methods=['post'])
@validate_request("table_name", "namespace")
def table_tracking_job():
    request_info = request.json
    data = DataTableTracker.track_job(request_info.get("table_name"), request_info.get("namespace"), display=True)
    return get_json_result(data=data)


def get_job_all_table(job):
    dsl_parser = schedule_utils.get_job_dsl_parser(dsl=job.f_dsl,
                                                   runtime_conf=job.f_runtime_conf,
                                                   train_runtime_conf=job.f_train_runtime_conf
                                                   )
    _, hierarchical_structure = dsl_parser.get_dsl_hierarchical_structure()
    component_table = {}
    try:
        component_output_tables = Tracker.query_output_data_infos(job_id=job.f_job_id, role=job.f_role,
                                                                  party_id=job.f_party_id)
    except Exception:
        component_output_tables = []
    for component_name_list in hierarchical_structure:
        for component_name in component_name_list:
            component_table[component_name] = {}
            component_input_table = get_component_input_table(dsl_parser, job, component_name)
            component_table[component_name]['input'] = component_input_table
            component_table[component_name]['output'] = {}
            for output_table in component_output_tables:
                if output_table.f_component_name == component_name:
                    component_table[component_name]['output'][output_table.f_data_name] = \
                        {'name': output_table.f_table_name, 'namespace': output_table.f_table_namespace}
    return component_table


def get_component_input_table(dsl_parser, job, component_name):
    component = dsl_parser.get_component_info(component_name=component_name)
    module_name = get_component_module(component_name, job.f_dsl)
    if 'reader' == module_name.lower():
        # Reader components take their input table straight from the runtime conf.
        party_index = str(job.f_roles.get(job.f_role).index(int(job.f_party_id)))
        return job.f_runtime_conf.get("component_parameters", {}).get("role", {}) \
            .get(job.f_role, {}).get(party_index, {}).get(component_name)
    task_input_dsl = component.get_input()
    job_args_on_party = TaskExecutor.get_job_args_on_party(dsl_parser=dsl_parser,
                                                           job_runtime_conf=job.f_runtime_conf, role=job.f_role,
                                                           party_id=job.f_party_id)
    config = job_utils.get_job_parameters(job.f_job_id, job.f_role, job.f_party_id)
    task_parameters = RunParameters(**config)
    job_parameters = task_parameters
    component_input_table = TaskExecutor.get_task_run_args(job_id=job.f_job_id, role=job.f_role,
                                                           party_id=job.f_party_id,
                                                           task_id=None,
                                                           task_version=None,
                                                           job_args=job_args_on_party,
                                                           job_parameters=job_parameters,
                                                           task_parameters=task_parameters,
                                                           input_dsl=task_input_dsl,
                                                           get_input_table=True
                                                           )
    return component_input_table


def get_component_module(component_name, job_dsl):
    return job_dsl["components"][component_name]["module"].lower()


def adapter_request_data(request_data):
    if request_data.get("table_name"):
        request_data["name"] = request_data.get("table_name")


def get_bind_table_schema(id_column, feature_column):
    schema = {}
    if id_column and feature_column:
        schema = {'header': feature_column, 'sid': id_column}
    elif id_column:
        schema = {'sid': id_column, 'header': ''}
    return schema
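
# For example (illustrative):
#   get_bind_table_schema("id", "x0,x1,x2")  -> {'header': 'x0,x1,x2', 'sid': 'id'}
#   get_bind_table_schema("id", None)        -> {'sid': 'id', 'header': ''}
#   get_bind_table_schema(None, None)        -> {}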


def update_bind_table_schema(id_column, feature_column, extend_sid, id_delimiter):
    schema = None
    if id_column and feature_column:
        schema = {'header': id_delimiter.join([id_column, feature_column]), 'sid': get_extend_id_name()}
    elif id_column:
        schema = {'header': id_column, 'sid': get_extend_id_name()}
    # Guard against a missing id_column, which would otherwise leave schema as
    # None and raise an AttributeError below.
    if schema is not None:
        schema['extend_tag'] = True
    return schema