model_app.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import os
  17. import shutil
  18. from copy import deepcopy
  19. from uuid import uuid1
  20. import peewee
  21. from flask import abort, request, send_file
  22. from fate_arch.common import FederatedMode
  23. from fate_arch.common.base_utils import json_dumps, json_loads
  24. from fate_flow.db.db_models import (
  25. DB, ModelTag, PipelineComponentMeta, Tag,
  26. MachineLearningModelInfo as MLModel,
  27. )
  28. from fate_flow.db.runtime_config import RuntimeConfig
  29. from fate_flow.db.service_registry import ServerRegistry
  30. from fate_flow.entity import JobConfigurationBase
  31. from fate_flow.entity.types import ModelOperation, TagOperation
  32. from fate_flow.model.sync_model import SyncComponent, SyncModel
  33. from fate_flow.pipelined_model import deploy_model, migrate_model, publish_model
  34. from fate_flow.pipelined_model.pipelined_model import PipelinedModel
  35. from fate_flow.scheduler.dag_scheduler import DAGScheduler
  36. from fate_flow.settings import ENABLE_MODEL_STORE, IS_STANDALONE, TEMP_DIRECTORY, stat_logger
  37. from fate_flow.utils import detect_utils, job_utils, model_utils
  38. from fate_flow.utils.api_utils import (
  39. error_response, federated_api, get_json_result,
  40. send_file_in_mem, validate_request,
  41. )
  42. from fate_flow.utils.base_utils import compare_version
  43. from fate_flow.utils.config_adapter import JobRuntimeConfigAdapter
  44. from fate_flow.utils.job_utils import PIPELINE_COMPONENT_NAME
  45. from fate_flow.utils.schedule_utils import get_dsl_parser_by_version
  46. @manager.route('/load', methods=['POST'])
  47. def load_model():
  48. request_config = request.json
  49. if request_config.get('job_id'):
  50. retcode, retmsg, data = model_utils.query_model_info(model_version=request_config['job_id'], role='guest')
  51. if not data:
  52. return get_json_result(
  53. retcode=101,
  54. retmsg=f"Model with version {request_config.get('job_id')} can not be found in database. "
  55. "Please check if the model version is valid.",
  56. )
  57. model_info = data[0]
  58. request_config['initiator'] = {}
  59. request_config['initiator']['party_id'] = str(model_info.get('f_initiator_party_id'))
  60. request_config['initiator']['role'] = model_info.get('f_initiator_role')
  61. runtime_conf = model_info.get('f_runtime_conf', {}) if model_info.get('f_runtime_conf', {}) else model_info.get('f_train_runtime_conf', {})
  62. adapter = JobRuntimeConfigAdapter(runtime_conf)
  63. job_parameters = adapter.get_common_parameters().to_dict()
  64. request_config['job_parameters'] = job_parameters if job_parameters else model_info.get('f_train_runtime_conf', {}).get('job_parameters')
  65. roles = runtime_conf.get('role')
  66. request_config['role'] = roles if roles else model_info.get('f_train_runtime_conf', {}).get('role')
  67. for key, value in request_config['role'].items():
  68. for i, v in enumerate(value):
  69. value[i] = str(v)
  70. request_config.pop('job_id')
  71. _job_id = job_utils.generate_job_id()
  72. initiator_party_id = request_config['initiator']['party_id']
  73. initiator_role = request_config['initiator']['role']
  74. publish_model.generate_publish_model_info(request_config)
  75. load_status = True
  76. load_status_info = {}
  77. load_status_msg = 'success'
  78. load_status_info['detail'] = {}
  79. if "federated_mode" not in request_config['job_parameters']:
  80. if IS_STANDALONE:
  81. request_config['job_parameters']["federated_mode"] = FederatedMode.SINGLE
  82. else:
  83. request_config['job_parameters']["federated_mode"] = FederatedMode.MULTIPLE
  84. for role_name, role_partys in request_config.get("role").items():
  85. if role_name == 'arbiter':
  86. continue
  87. load_status_info[role_name] = load_status_info.get(role_name, {})
  88. load_status_info['detail'][role_name] = {}
  89. for _party_id in role_partys:
  90. request_config['local'] = {'role': role_name, 'party_id': _party_id}
  91. try:
  92. response = federated_api(job_id=_job_id,
  93. method='POST',
  94. endpoint='/model/load/do',
  95. src_party_id=initiator_party_id,
  96. dest_party_id=_party_id,
  97. src_role = initiator_role,
  98. json_body=request_config,
  99. federated_mode=request_config['job_parameters']['federated_mode'])
  100. load_status_info[role_name][_party_id] = response['retcode']
  101. detail = {_party_id: {}}
  102. detail[_party_id]['retcode'] = response['retcode']
  103. detail[_party_id]['retmsg'] = response['retmsg']
  104. load_status_info['detail'][role_name].update(detail)
  105. if response['retcode']:
  106. load_status = False
  107. load_status_msg = 'failed'
  108. except Exception as e:
  109. stat_logger.exception(e)
  110. load_status = False
  111. load_status_msg = 'failed'
  112. load_status_info[role_name][_party_id] = 100
  113. return get_json_result(job_id=_job_id, retcode=(0 if load_status else 101), retmsg=load_status_msg,
  114. data=load_status_info)
  115. @manager.route('/migrate', methods=['POST'])
  116. @validate_request("migrate_initiator", "role", "migrate_role", "model_id",
  117. "model_version", "execute_party", "job_parameters")
  118. def migrate_model_process():
  119. request_config = request.json
  120. _job_id = job_utils.generate_job_id()
  121. initiator_party_id = request_config['migrate_initiator']['party_id']
  122. initiator_role = request_config['migrate_initiator']['role']
  123. if not request_config.get("unify_model_version"):
  124. request_config["unify_model_version"] = _job_id
  125. migrate_status = True
  126. migrate_status_info = {}
  127. migrate_status_msg = 'success'
  128. migrate_status_info['detail'] = {}
  129. try:
  130. if migrate_model.compare_roles(request_config.get("migrate_role"), request_config.get("role")):
  131. return get_json_result(retcode=100,
  132. retmsg="The config of previous roles is the same with that of migrate roles. "
  133. "There is no need to migrate model. Migration process aborting.")
  134. except Exception as e:
  135. return get_json_result(retcode=100, retmsg=str(e))
  136. local_template = {
  137. "role": "",
  138. "party_id": "",
  139. "migrate_party_id": ""
  140. }
  141. res_dict = {}
  142. for role_name, role_partys in request_config.get("migrate_role").items():
  143. for offset, party_id in enumerate(role_partys):
  144. local_res = deepcopy(local_template)
  145. local_res["role"] = role_name
  146. local_res["party_id"] = request_config.get("role").get(role_name)[offset]
  147. local_res["migrate_party_id"] = party_id
  148. if not res_dict.get(role_name):
  149. res_dict[role_name] = {}
  150. res_dict[role_name][local_res["party_id"]] = local_res
  151. for role_name, role_partys in request_config.get("execute_party").items():
  152. migrate_status_info[role_name] = migrate_status_info.get(role_name, {})
  153. migrate_status_info['detail'][role_name] = {}
  154. for party_id in role_partys:
  155. request_config["local"] = res_dict.get(role_name).get(party_id)
  156. try:
  157. response = federated_api(job_id=_job_id,
  158. method='POST',
  159. endpoint='/model/migrate/do',
  160. src_party_id=initiator_party_id,
  161. dest_party_id=party_id,
  162. src_role=initiator_role,
  163. json_body=request_config,
  164. federated_mode=request_config['job_parameters']['federated_mode'])
  165. migrate_status_info[role_name][party_id] = response['retcode']
  166. detail = {party_id: {}}
  167. detail[party_id]['retcode'] = response['retcode']
  168. detail[party_id]['retmsg'] = response['retmsg']
  169. migrate_status_info['detail'][role_name].update(detail)
  170. except Exception as e:
  171. stat_logger.exception(e)
  172. migrate_status = False
  173. migrate_status_msg = 'failed'
  174. migrate_status_info[role_name][party_id] = 100
  175. return get_json_result(job_id=_job_id, retcode=(0 if migrate_status else 101),
  176. retmsg=migrate_status_msg, data=migrate_status_info)
  177. @manager.route('/migrate/do', methods=['POST'])
  178. def do_migrate_model():
  179. request_data = request.json
  180. retcode, retmsg, data = migrate_model.migration(request_data)
  181. return get_json_result(retcode=retcode, retmsg=retmsg, data=data)
  182. @manager.route('/load/do', methods=['POST'])
  183. def do_load_model():
  184. request_data = request.json
  185. request_data['servings'] = RuntimeConfig.SERVICE_DB.get_urls('servings')
  186. role = request_data['local']['role']
  187. party_id = request_data['local']['party_id']
  188. model_id = request_data['job_parameters']['model_id']
  189. model_version = request_data['job_parameters']['model_version']
  190. if ENABLE_MODEL_STORE:
  191. sync_model = SyncModel(
  192. role=role, party_id=party_id,
  193. model_id=model_id, model_version=model_version,
  194. )
  195. if sync_model.remote_exists():
  196. sync_model.download(True)
  197. if not model_utils.check_if_deployed(role, party_id, model_id, model_version):
  198. return get_json_result(retcode=100,
  199. retmsg="Only deployed models could be used to execute process of loading. "
  200. "Please deploy model before loading.")
  201. retcode, retmsg = publish_model.load_model(request_data)
  202. try:
  203. if not retcode:
  204. with DB.connection_context():
  205. model = MLModel.get_or_none(
  206. MLModel.f_role == role,
  207. MLModel.f_party_id == party_id,
  208. MLModel.f_model_id == model_id,
  209. MLModel.f_model_version == model_version,
  210. )
  211. if model:
  212. model.f_loaded_times += 1
  213. model.save()
  214. except Exception as modify_err:
  215. stat_logger.exception(modify_err)
  216. return get_json_result(retcode=retcode, retmsg=retmsg)
  217. @manager.route('/bind', methods=['POST'])
  218. def bind_model_service():
  219. request_config = request.json
  220. if request_config.get('job_id'):
  221. retcode, retmsg, data = model_utils.query_model_info(model_version=request_config['job_id'], role='guest')
  222. if not data:
  223. return get_json_result(
  224. retcode=101,
  225. retmsg=f"Model {request_config.get('job_id')} can not be found in database. "
  226. "Please check if the model version is valid."
  227. )
  228. model_info = data[0]
  229. request_config['initiator'] = {}
  230. request_config['initiator']['party_id'] = str(model_info.get('f_initiator_party_id'))
  231. request_config['initiator']['role'] = model_info.get('f_initiator_role')
  232. runtime_conf = model_info.get('f_runtime_conf', {}) if model_info.get('f_runtime_conf', {}) else model_info.get('f_train_runtime_conf', {})
  233. adapter = JobRuntimeConfigAdapter(runtime_conf)
  234. job_parameters = adapter.get_common_parameters().to_dict()
  235. request_config['job_parameters'] = job_parameters if job_parameters else model_info.get('f_train_runtime_conf', {}).get('job_parameters')
  236. roles = runtime_conf.get('role')
  237. request_config['role'] = roles if roles else model_info.get('f_train_runtime_conf', {}).get('role')
  238. for key, value in request_config['role'].items():
  239. for i, v in enumerate(value):
  240. value[i] = str(v)
  241. request_config.pop('job_id')
  242. if not request_config.get('servings'):
  243. # get my party all servings
  244. request_config['servings'] = RuntimeConfig.SERVICE_DB.get_urls('servings')
  245. service_id = request_config.get('service_id')
  246. if not service_id:
  247. return get_json_result(retcode=101, retmsg='no service id')
  248. detect_utils.check_config(request_config, ['initiator', 'role', 'job_parameters'])
  249. bind_status, retmsg = publish_model.bind_model_service(request_config)
  250. return get_json_result(retcode=bind_status, retmsg='service id is {}'.format(service_id) if not retmsg else retmsg)
  251. @manager.route('/transfer', methods=['post'])
  252. def transfer_model():
  253. party_model_id = request.json.get('namespace')
  254. model_version = request.json.get('name')
  255. if not party_model_id or not model_version:
  256. return error_response(400, 'namespace and name are required')
  257. model_data = publish_model.download_model(party_model_id, model_version)
  258. if not model_data:
  259. return error_response(404, 'model not found')
  260. return get_json_result(data=model_data)
  261. @manager.route('/transfer/<party_model_id>/<model_version>', methods=['post'])
  262. def download_model(party_model_id, model_version):
  263. party_model_id = party_model_id.replace('~', '#')
  264. model_data = publish_model.download_model(party_model_id, model_version)
  265. if not model_data:
  266. return error_response(404, 'model not found')
  267. return get_json_result(data=model_data)
  268. @manager.route('/<model_operation>', methods=['post', 'get'])
  269. @validate_request("model_id", "model_version", "role", "party_id")
  270. def operate_model(model_operation):
  271. request_config = request.json or request.form.to_dict()
  272. job_id = job_utils.generate_job_id()
  273. # TODO: export, import, store, restore should NOT be in the same function
  274. if not ModelOperation.valid(model_operation):
  275. raise Exception(f'Not supported model operation: "{model_operation}".')
  276. model_operation = ModelOperation(model_operation)
  277. request_config['party_id'] = str(request_config['party_id'])
  278. request_config['model_version'] = str(request_config['model_version'])
  279. party_model_id = model_utils.gen_party_model_id(
  280. request_config['model_id'],
  281. request_config['role'],
  282. request_config['party_id'],
  283. )
  284. if model_operation in [ModelOperation.EXPORT, ModelOperation.IMPORT]:
  285. if model_operation is ModelOperation.IMPORT:
  286. file = request.files.get('file')
  287. if not file:
  288. return error_response(400, '`file` is required.')
  289. force_update = bool(int(request_config.get('force_update', 0)))
  290. if not force_update:
  291. with DB.connection_context():
  292. if MLModel.get_or_none(
  293. MLModel.f_role == request_config['role'],
  294. MLModel.f_party_id == request_config['party_id'],
  295. MLModel.f_model_id == request_config['model_id'],
  296. MLModel.f_model_version == request_config['model_version'],
  297. ):
  298. return error_response(409, 'Model already exists.')
  299. filename = os.path.join(TEMP_DIRECTORY, uuid1().hex)
  300. os.makedirs(os.path.dirname(filename), exist_ok=True)
  301. try:
  302. file.save(filename)
  303. except Exception as e:
  304. try:
  305. filename.unlink()
  306. except FileNotFoundError:
  307. pass
  308. return error_response(500, f'Save file error: {e}')
  309. model = PipelinedModel(party_model_id, request_config['model_version'])
  310. model.unpack_model(filename, force_update, request_config.get('hash'))
  311. pipeline = model.read_pipeline_model()
  312. train_runtime_conf = json_loads(pipeline.train_runtime_conf)
  313. for _party_id in train_runtime_conf['role'].get(request_config['role'], []):
  314. if request_config['party_id'] == str(_party_id):
  315. break
  316. else:
  317. shutil.rmtree(model.model_path, ignore_errors=True)
  318. return error_response(
  319. 400,
  320. f'Party id "{request_config["party_id"]}" is not in role "{request_config["role"]}", '
  321. f'please check if the party id and role is valid.',
  322. )
  323. model.pipelined_component.save_define_meta_from_file_to_db(force_update)
  324. if ENABLE_MODEL_STORE:
  325. query = model.pipelined_component.get_define_meta_from_db(
  326. PipelineComponentMeta.f_component_name != PIPELINE_COMPONENT_NAME,
  327. )
  328. for row in query:
  329. sync_component = SyncComponent(
  330. role=request_config['role'], party_id=request_config['party_id'],
  331. model_id=request_config['model_id'], model_version=request_config['model_version'],
  332. component_name=row.f_component_name,
  333. )
  334. sync_component.upload()
  335. pipeline.model_id = request_config['model_id']
  336. pipeline.model_version = request_config['model_version']
  337. train_runtime_conf = JobRuntimeConfigAdapter(
  338. train_runtime_conf,
  339. ).update_model_id_version(
  340. model_id=request_config['model_id'],
  341. model_version=request_config['model_version'],
  342. )
  343. if compare_version(pipeline.fate_version, '1.5.0') == 'gt':
  344. runtime_conf_on_party = json_loads(pipeline.runtime_conf_on_party)
  345. runtime_conf_on_party['job_parameters']['model_id'] = request_config['model_id']
  346. runtime_conf_on_party['job_parameters']['model_version'] = request_config['model_version']
  347. # fix migrate bug between 1.5.x and 1.8.x
  348. if compare_version(pipeline.fate_version, '1.9.0') == 'lt':
  349. pipeline.roles = json_dumps(train_runtime_conf['role'], byte=True)
  350. runtime_conf_on_party['role'] = train_runtime_conf['role']
  351. runtime_conf_on_party['initiator'] = train_runtime_conf['initiator']
  352. pipeline.runtime_conf_on_party = json_dumps(runtime_conf_on_party, byte=True)
  353. model.save_pipeline_model(pipeline, False)
  354. model_info = model_utils.gather_model_info_data(model)
  355. model_info['f_role'] = request_config['role']
  356. model_info['f_party_id'] = request_config['party_id']
  357. model_info['f_job_id'] = job_id
  358. model_info['f_imported'] = 1
  359. model_utils.save_model_info(model_info)
  360. return get_json_result(data={
  361. 'job_id': job_id,
  362. 'role': request_config['role'],
  363. 'party_id': request_config['party_id'],
  364. 'model_id': request_config['model_id'],
  365. 'model_version': request_config['model_version'],
  366. })
  367. # export
  368. else:
  369. if ENABLE_MODEL_STORE:
  370. sync_model = SyncModel(
  371. role=request_config['role'], party_id=request_config['party_id'],
  372. model_id=request_config['model_id'], model_version=request_config['model_version'],
  373. )
  374. if sync_model.remote_exists():
  375. sync_model.download(True)
  376. model = PipelinedModel(party_model_id, request_config["model_version"])
  377. if not model.exists():
  378. return error_response(404, f"Model {party_model_id} {request_config['model_version']} does not exist.")
  379. model.packaging_model()
  380. return send_file(
  381. model.archive_model_file_path,
  382. as_attachment=True,
  383. attachment_filename=os.path.basename(model.archive_model_file_path),
  384. )
  385. # store and restore
  386. else:
  387. request_config['model_id'] = party_model_id
  388. job_dsl, job_runtime_conf = gen_model_operation_job_config(request_config, model_operation)
  389. submit_result = DAGScheduler.submit(JobConfigurationBase(**{'dsl': job_dsl, 'runtime_conf': job_runtime_conf}), job_id=job_id)
  390. return get_json_result(job_id=job_id, data=submit_result)
  391. @manager.route('/model_tag/<operation>', methods=['POST'])
  392. @DB.connection_context()
  393. def tag_model(operation):
  394. if operation not in ['retrieve', 'create', 'remove']:
  395. return get_json_result(retcode=100, retmsg="'{}' is not currently supported.".format(operation))
  396. request_data = request.json
  397. model = MLModel.get_or_none(MLModel.f_model_version == request_data.get("job_id"))
  398. if not model:
  399. raise Exception("Can not found model by job id: '{}'.".format(request_data.get("job_id")))
  400. if operation == 'retrieve':
  401. res = {'tags': []}
  402. tags = (Tag.select().join(ModelTag, on=ModelTag.f_t_id == Tag.f_id).where(ModelTag.f_m_id == model.f_model_version))
  403. for tag in tags:
  404. res['tags'].append({'name': tag.f_name, 'description': tag.f_desc})
  405. res['count'] = tags.count()
  406. return get_json_result(data=res)
  407. elif operation == 'remove':
  408. tag = Tag.get_or_none(Tag.f_name == request_data.get('tag_name'))
  409. if not tag:
  410. raise Exception("Can not found '{}' tag.".format(request_data.get('tag_name')))
  411. tags = (Tag.select().join(ModelTag, on=ModelTag.f_t_id == Tag.f_id).where(ModelTag.f_m_id == model.f_model_version))
  412. if tag.f_name not in [t.f_name for t in tags]:
  413. raise Exception("Model {} {} does not have tag '{}'.".format(model.f_model_id,
  414. model.f_model_version,
  415. tag.f_name))
  416. delete_query = ModelTag.delete().where(ModelTag.f_m_id == model.f_model_version, ModelTag.f_t_id == tag.f_id)
  417. delete_query.execute()
  418. return get_json_result(retmsg="'{}' tag has been removed from tag list of model {} {}.".format(request_data.get('tag_name'),
  419. model.f_model_id,
  420. model.f_model_version))
  421. else:
  422. if not str(request_data.get('tag_name')):
  423. raise Exception("Tag name should not be an empty string.")
  424. tag = Tag.get_or_none(Tag.f_name == request_data.get('tag_name'))
  425. if not tag:
  426. tag = Tag()
  427. tag.f_name = request_data.get('tag_name')
  428. tag.save(force_insert=True)
  429. else:
  430. tags = (Tag.select().join(ModelTag, on=ModelTag.f_t_id == Tag.f_id).where(ModelTag.f_m_id == model.f_model_version))
  431. if tag.f_name in [t.f_name for t in tags]:
  432. raise Exception("Model {} {} already been tagged as tag '{}'.".format(model.f_model_id,
  433. model.f_model_version,
  434. tag.f_name))
  435. ModelTag.create(f_t_id=tag.f_id, f_m_id=model.f_model_version)
  436. return get_json_result(retmsg="Adding {} tag for model with job id: {} successfully.".format(request_data.get('tag_name'),
  437. request_data.get('job_id')))
  438. @manager.route('/tag/<tag_operation>', methods=['POST'])
  439. @DB.connection_context()
  440. def operate_tag(tag_operation):
  441. request_data = request.json
  442. if not TagOperation.valid(tag_operation):
  443. raise Exception('The {} operation is not currently supported.'.format(tag_operation))
  444. tag_name = request_data.get('tag_name')
  445. tag_desc = request_data.get('tag_desc')
  446. tag_operation = TagOperation(tag_operation)
  447. if tag_operation is TagOperation.CREATE:
  448. try:
  449. if not tag_name:
  450. return get_json_result(retcode=100, retmsg="'{}' tag created failed. Please input a valid tag name.".format(tag_name))
  451. else:
  452. Tag.create(f_name=tag_name, f_desc=tag_desc)
  453. except peewee.IntegrityError:
  454. raise Exception("'{}' has already exists in database.".format(tag_name))
  455. else:
  456. return get_json_result(retmsg="'{}' tag has been created successfully.".format(tag_name))
  457. elif tag_operation is TagOperation.LIST:
  458. tags = Tag.select()
  459. limit = request_data.get('limit')
  460. res = {"tags": []}
  461. if limit > len(tags):
  462. count = len(tags)
  463. else:
  464. count = limit
  465. for tag in tags[:count]:
  466. res['tags'].append({'name': tag.f_name, 'description': tag.f_desc,
  467. 'model_count': ModelTag.filter(ModelTag.f_t_id == tag.f_id).count()})
  468. return get_json_result(data=res)
  469. else:
  470. if not (tag_operation is TagOperation.RETRIEVE and not request_data.get('with_model')):
  471. try:
  472. tag = Tag.get(Tag.f_name == tag_name)
  473. except peewee.DoesNotExist:
  474. raise Exception("Can not found '{}' tag.".format(tag_name))
  475. if tag_operation is TagOperation.RETRIEVE:
  476. if request_data.get('with_model', False):
  477. res = {'models': []}
  478. models = (MLModel.select().join(ModelTag, on=ModelTag.f_m_id == MLModel.f_model_version).where(ModelTag.f_t_id == tag.f_id))
  479. for model in models:
  480. res["models"].append({
  481. "model_id": model.f_model_id,
  482. "model_version": model.f_model_version,
  483. "model_size": model.f_size,
  484. "role": model.f_role,
  485. "party_id": model.f_party_id
  486. })
  487. res["count"] = models.count()
  488. return get_json_result(data=res)
  489. else:
  490. tags = Tag.filter(Tag.f_name.contains(tag_name))
  491. if not tags:
  492. return get_json_result(retcode=100, retmsg="No tags found.")
  493. res = {'tags': []}
  494. for tag in tags:
  495. res['tags'].append({'name': tag.f_name, 'description': tag.f_desc})
  496. return get_json_result(data=res)
  497. elif tag_operation is TagOperation.UPDATE:
  498. new_tag_name = request_data.get('new_tag_name', None)
  499. new_tag_desc = request_data.get('new_tag_desc', None)
  500. if (tag.f_name == new_tag_name) and (tag.f_desc == new_tag_desc):
  501. return get_json_result(100, "Nothing to be updated.")
  502. else:
  503. if request_data.get('new_tag_name'):
  504. if not Tag.get_or_none(Tag.f_name == new_tag_name):
  505. tag.f_name = new_tag_name
  506. else:
  507. return get_json_result(100, retmsg="'{}' tag already exists.".format(new_tag_name))
  508. tag.f_desc = new_tag_desc
  509. tag.save()
  510. return get_json_result(retmsg="Infomation of '{}' tag has been updated successfully.".format(tag_name))
  511. else:
  512. delete_query = ModelTag.delete().where(ModelTag.f_t_id == tag.f_id)
  513. delete_query.execute()
  514. Tag.delete_instance(tag)
  515. return get_json_result(retmsg="'{}' tag has been deleted successfully.".format(tag_name))
  516. def gen_model_operation_job_config(config_data: dict, model_operation: ModelOperation):
  517. if model_operation not in {ModelOperation.STORE, ModelOperation.RESTORE}:
  518. raise Exception("Can not support this model operation: {}".format(model_operation))
  519. component_name = f"{str(model_operation).replace('.', '_').lower()}_0"
  520. job_dsl = {
  521. "components": {
  522. component_name: {
  523. "module": "Model{}".format(model_operation.value.capitalize()),
  524. },
  525. },
  526. }
  527. job_runtime_conf = job_utils.runtime_conf_basic(True)
  528. component_parameters = {
  529. "model_id": config_data["model_id"],
  530. "model_version": config_data["model_version"],
  531. "store_address": ServerRegistry.MODEL_STORE_ADDRESS,
  532. }
  533. if model_operation == ModelOperation.STORE:
  534. component_parameters["force_update"] = config_data.get("force_update", False)
  535. elif model_operation == ModelOperation.RESTORE:
  536. component_parameters["hash_"] = config_data.get("sha256", None)
  537. job_runtime_conf["component_parameters"]["role"] = {
  538. "local": {
  539. "0": {
  540. component_name: component_parameters,
  541. },
  542. },
  543. }
  544. return job_dsl, job_runtime_conf
  545. @manager.route('/query', methods=['POST'])
  546. def query_model():
  547. request_data = request.json or request.form.to_dict() or {}
  548. retcode, retmsg, data = model_utils.query_model_info(**request_data)
  549. return get_json_result(retcode=retcode, retmsg=retmsg, data=data)
  550. @manager.route('/deploy', methods=['POST'])
  551. @validate_request('model_id', 'model_version')
  552. def deploy():
  553. request_data = request.json
  554. model_id = request_data['model_id']
  555. model_version = request_data['model_version']
  556. if not isinstance(request_data.get('components_checkpoint'), dict):
  557. request_data['components_checkpoint'] = {}
  558. retcode, retmsg, data = model_utils.query_model_info(model_id=model_id, model_version=model_version)
  559. if not data:
  560. return error_response(
  561. 404,
  562. 'Deploy model failed. '
  563. f'Model {model_id} {model_version} not found.'
  564. )
  565. for model_info in data:
  566. version_check = compare_version(model_info.get('f_fate_version'), '1.5.0')
  567. if version_check == 'lt':
  568. continue
  569. initiator_role = (model_info['f_initiator_role'] if model_info.get('f_initiator_role')
  570. else model_info.get('f_train_runtime_conf', {}).get('initiator', {}).get('role', ''))
  571. initiator_party_id = (model_info['f_initiator_party_id'] if model_info.get('f_initiator_party_id')
  572. else model_info.get('f_train_runtime_conf', {}).get('initiator', {}).get('party_id', ''))
  573. if model_info['f_role'] == initiator_role and str(model_info['f_party_id']) == str(initiator_party_id):
  574. break
  575. else:
  576. return error_response(
  577. 404,
  578. 'Deploy model failed. '
  579. 'Cannot found model of initiator role or the fate version of model is older than 1.5.0',
  580. )
  581. roles = (
  582. data[0].get('f_roles') or
  583. data[0].get('f_train_runtime_conf', {}).get('role') or
  584. data[0].get('f_runtime_conf', {}).get('role')
  585. )
  586. if not roles:
  587. return error_response(
  588. 404,
  589. 'Deploy model failed. '
  590. 'Cannot found roles of model.'
  591. )
  592. # distribute federated deploy task
  593. _job_id = job_utils.generate_job_id()
  594. request_data['child_model_version'] = _job_id
  595. request_data['initiator'] = {
  596. 'role': initiator_role,
  597. 'party_id': initiator_party_id,
  598. }
  599. deploy_status = True
  600. deploy_status_info = {
  601. 'detail': {},
  602. 'model_id': model_id,
  603. 'model_version': _job_id,
  604. }
  605. for role_name, role_partys in roles.items():
  606. if role_name not in {'arbiter', 'host', 'guest'}:
  607. continue
  608. if role_name not in deploy_status_info:
  609. deploy_status_info[role_name] = {}
  610. if role_name not in deploy_status_info['detail']:
  611. deploy_status_info['detail'][role_name] = {}
  612. for _party_id in role_partys:
  613. request_data['local'] = {
  614. 'role': role_name,
  615. 'party_id': _party_id,
  616. }
  617. try:
  618. response = federated_api(
  619. job_id=_job_id,
  620. method='POST',
  621. endpoint='/model/deploy/do',
  622. src_party_id=initiator_party_id,
  623. dest_party_id=_party_id,
  624. src_role=initiator_role,
  625. json_body=request_data,
  626. federated_mode=FederatedMode.MULTIPLE if not IS_STANDALONE else FederatedMode.SINGLE
  627. )
  628. if response['retcode']:
  629. deploy_status = False
  630. deploy_status_info[role_name][_party_id] = response['retcode']
  631. deploy_status_info['detail'][role_name][_party_id] = {
  632. 'retcode': response['retcode'],
  633. 'retmsg': response['retmsg'],
  634. }
  635. except Exception as e:
  636. deploy_status = False
  637. deploy_status_info[role_name][_party_id] = 100
  638. deploy_status_info['detail'][role_name][_party_id] = {
  639. 'retcode': 100,
  640. 'retmsg': 'request failed',
  641. }
  642. stat_logger.exception(e)
  643. return get_json_result(
  644. 0 if deploy_status else 101,
  645. 'success' if deploy_status else 'failed',
  646. deploy_status_info,
  647. )
  @manager.route('/deploy/do', methods=['POST'])
  def do_deploy():
      """Run the local part of a model deployment.

      The JSON body is handed unchanged to ``deploy_model.deploy``. Its
      ``(retcode, retmsg)`` pair becomes the JSON response.

      NOTE(review): this looks like the receiving end of the federated
      '/model/deploy/do' POSTs issued by the deploy handler in this module.
      Confirm that '/model' is this blueprint's URL prefix.
      """
      payload = request.json
      code, message = deploy_model.deploy(payload)
      return get_json_result(retcode=code, retmsg=message)
  def get_dsl_and_conf():
      """Look up the guest/host model record addressed by the current request.

      The request body (JSON, else form fields, else an empty dict) is passed
      as keyword arguments to ``model_utils.query_model_info``. Before that,
      its ``query_filters`` entry is overwritten with a fixed list of field
      names (presumably the columns to fetch; confirm in model_utils).

      Returns:
          A ``(request_data, model_info)`` tuple:
          - the request dict, now carrying ``query_filters``;
          - the first queried record whose ``f_role`` is ``guest`` or ``host``.

      Aborts the request with a 210 error response when the query returns
      nothing, or when none of the returned records has a guest/host role.
      """
      request_data = request.json or request.form.to_dict() or {}
      request_data['query_filters'] = [
          'model_id',
          'model_version',
          'role',
          'party_id',
          'train_runtime_conf',
          'inference_dsl',
      ]
      # Only the records are needed; the status code and message are ignored.
      _, _, records = model_utils.query_model_info(**request_data)
      if not records:
          abort(error_response(
              210,
              'No model found, '
              'please check if arguments are specified correctly.',
          ))

      # Pick the first record (in query order) that belongs to a guest or host.
      model_info = next(
          (record for record in records if record.get('f_role') in {'guest', 'host'}),
          None,
      )
      if model_info is None:
          abort(error_response(
              210,
              'Cannot found guest or host model, '
              'please get predict dsl on guest or host.',
          ))
      return request_data, model_info
  @manager.route('/get/predict/dsl', methods=['POST'])
  def get_predict_dsl():
      """Return the inference DSL of the requested guest/host model.

      If the request carries a truthy ``filename``, the DSL is sent as a file
      via ``send_file_in_mem``. Otherwise it is embedded in a JSON result.
      """
      request_data, model_info = get_dsl_and_conf()
      inference_dsl = model_info['f_inference_dsl']
      filename = request_data.get('filename')
      if not filename:
          return get_json_result(data=inference_dsl)
      return send_file_in_mem(inference_dsl, filename)
  @manager.route('/get/predict/conf', methods=['POST'])
  def get_predict_conf():
      """Build a predict-job conf template for the requested guest/host model.

      The DSL parser is chosen from the training runtime conf's
      ``dsl_version`` (1 when that key is absent). A truthy ``filename`` in
      the request sends the template as a file via ``send_file_in_mem``.
      Otherwise it is returned inside a JSON result.
      """
      request_data, model_info = get_dsl_and_conf()
      train_conf = model_info['f_train_runtime_conf']
      dsl_parser = get_dsl_parser_by_version(train_conf.get('dsl_version', 1))
      predict_conf = dsl_parser.generate_predict_conf_template(
          model_info['f_inference_dsl'],
          train_conf,
          model_info['f_model_id'],
          model_info['f_model_version'],
      )

      filename = request_data.get('filename')
      if filename:
          return send_file_in_mem(predict_conf, filename)
      return get_json_result(data=predict_conf)
  @manager.route('/archive/packaging', methods=['POST'])
  @validate_request('party_model_id', 'model_version')
  def packaging_model():
      """Package a pipelined model and report where the archive was written.

      When ENABLE_MODEL_STORE is set and the model exists in the remote store,
      it is downloaded first via ``SyncModel.download(True)``.

      Returns:
          - a 404 error response if the model does not exist locally;
          - otherwise the model identity, ``archive_model_file_path`` and the
            hash returned by ``PipelinedModel.packaging_model``.
      """
      request_data = request.json or request.form.to_dict()
      party_model_id = request_data['party_model_id']
      model_version = request_data['model_version']

      if ENABLE_MODEL_STORE:
          store_copy = SyncModel(party_model_id=party_model_id, model_version=model_version)
          if store_copy.remote_exists():
              store_copy.download(True)

      model = PipelinedModel(model_id=party_model_id, model_version=model_version)
      if not model.exists():
          return error_response(404, 'Model not found.')

      archive_hash = model.packaging_model()
      return get_json_result(data={
          'party_model_id': model.party_model_id,
          'model_version': model.model_version,
          'path': model.archive_model_file_path,
          'hash': archive_hash,
      })
  @manager.route('/service/register', methods=['POST'])
  @validate_request('party_model_id', 'model_version')
  def register_service():
      """Register a model in ``RuntimeConfig.SERVICE_DB`` and echo its identity.

      The request's ``party_model_id`` and ``model_version`` are passed to
      ``register_model`` as keyword arguments. The same two fields are
      returned in the JSON result.
      """
      request_data = request.json or request.form.to_dict()
      model_identity = {
          'party_model_id': request_data['party_model_id'],
          'model_version': request_data['model_version'],
      }
      # ** unpacking hands register_model its own kwargs dict, so the dict
      # echoed back below is never modified by the call.
      RuntimeConfig.SERVICE_DB.register_model(**model_identity)
      return get_json_result(data=model_identity)
  @manager.route('/homo/convert', methods=['POST'])
  @validate_request("model_id", "model_version", "role", "party_id")
  def homo_convert():
      """Delegate to ``publish_model.convert_homo_model`` and wrap its result.

      The request body (JSON, else form fields) is forwarded as-is. The
      returned ``(retcode, retmsg, data)`` triple becomes the JSON response.
      """
      payload = request.json or request.form.to_dict()
      code, message, converted = publish_model.convert_homo_model(payload)
      return get_json_result(retcode=code, retmsg=message, data=converted)
  @manager.route('/homo/deploy', methods=['POST'])
  @validate_request("service_id", "model_id", "model_version", "role", "party_id",
                    "component_name", "deployment_type", "deployment_parameters")
  def homo_deploy():
      """Delegate to ``publish_model.deploy_homo_model`` and wrap its result.

      The request body (JSON, else form fields) is forwarded as-is. The
      returned ``(retcode, retmsg, data)`` triple becomes the JSON response.
      """
      payload = request.json or request.form.to_dict()
      code, message, deployed = publish_model.deploy_homo_model(payload)
      return get_json_result(retcode=code, retmsg=message, data=deployed)