provider_app.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. from pathlib import Path
  18. from flask import request
  19. from fate_flow.db.component_registry import ComponentRegistry
  20. from fate_flow.entity import ComponentProvider, RetCode
  21. from fate_flow.entity.types import WorkerName
  22. from fate_flow.manager.worker_manager import WorkerManager
  23. from fate_flow.scheduler.cluster_scheduler import ClusterScheduler
  24. from fate_flow.utils.api_utils import (
  25. error_response, get_json_result,
  26. validate_request,
  27. )
  28. @manager.route('/update', methods=['POST'])
  29. def provider_update():
  30. request_data = request.json
  31. ComponentRegistry.load()
  32. if ComponentRegistry.get_providers().get(request_data.get("name"), {}).get(request_data.get("version"), None) is None:
  33. return get_json_result(retcode=RetCode.OPERATING_ERROR, retmsg=f"not load into memory")
  34. return get_json_result()
  35. @manager.route('/register', methods=['POST'])
  36. @validate_request("name", "version", "path")
  37. def register():
  38. info = request.json or request.form.to_dict()
  39. path = Path(info["path"]).absolute()
  40. if not path.is_dir():
  41. return error_response(400, f"path '{path}' is not a directory")
  42. if set(path.parent.iterdir()) - {path, (path.parent / "__init__.py")}:
  43. return error_response(400, f"there are other directories or files in '{path.parent}' besides '{path.name}' and '__init__.py'")
  44. provider = ComponentProvider(name=info["name"],
  45. version=info["version"],
  46. path=info["path"],
  47. class_path=info.get("class_path", ComponentRegistry.get_default_class_path()))
  48. code, std = WorkerManager.start_general_worker(worker_name=WorkerName.PROVIDER_REGISTRAR, provider=provider)
  49. if code != 0:
  50. return get_json_result(retcode=RetCode.OPERATING_ERROR, retmsg=f"register failed:\n{std}")
  51. federated_response = ClusterScheduler.cluster_command(
  52. "/provider/update",
  53. {
  54. "name": info["name"],
  55. "version": info["version"],
  56. },
  57. )
  58. return get_json_result(data=federated_response)
  59. @manager.route('/registry/get', methods=['POST'])
  60. def get_registry():
  61. return get_json_result(data=ComponentRegistry.REGISTRY)
  62. @manager.route('/get', methods=['POST'])
  63. def get_providers():
  64. providers = ComponentRegistry.get_providers()
  65. result = {}
  66. for name, group_detail in providers.items():
  67. result[name] = {}
  68. for version, detail in group_detail.items():
  69. result[name][version] = copy.deepcopy(detail)
  70. if "components" in detail:
  71. result[name][version]["components"] = set([c.lower() for c in detail["components"].keys()])
  72. return get_json_result(data=result)
  73. @manager.route('/<provider_name>/get', methods=['POST'])
  74. def get_provider(provider_name):
  75. return get_json_result(data=ComponentRegistry.get_providers().get(provider_name))