schedule_utils.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import typing
from functools import wraps

from fate_arch.common.base_utils import current_timestamp
from fate_flow.db.db_models import DB, Job
from fate_flow.scheduler.dsl_parser import DSLParserV1, DSLParserV2
from fate_flow.utils.config_adapter import JobRuntimeConfigAdapter
from fate_flow.utils.log_utils import schedule_logger


@DB.connection_context()
def ready_signal(job_id, set_or_reset: bool, ready_timeout_ttl=None):
    # set: atomically claim the job's ready signal (fails if it is already set);
    # reset: release it, optionally only when the signal is older than the ttl
    filters = [Job.f_job_id == job_id]
    if set_or_reset:
        update_fields = {Job.f_ready_signal: True, Job.f_ready_time: current_timestamp()}
        filters.append(Job.f_ready_signal == False)
    else:
        update_fields = {Job.f_ready_signal: False, Job.f_ready_time: None}
        filters.append(Job.f_ready_signal == True)
        if ready_timeout_ttl:
            filters.append(current_timestamp() - Job.f_ready_time > ready_timeout_ttl)
    update_status = Job.update(update_fields).where(*filters).execute() > 0
    return update_status
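

# Usage sketch (illustrative, not part of the original module): the ready signal
# works as a compare-and-set lock row in the job table. Setting succeeds only if
# no other scheduler currently holds it; resetting releases it. The job id below
# is hypothetical.
#
#   if ready_signal(job_id="202201011200000000000", set_or_reset=True):
#       ...  # this scheduler now owns the job
#       ready_signal(job_id="202201011200000000000", set_or_reset=False)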


@DB.connection_context()
def cancel_signal(job_id, set_or_reset: bool):
    update_status = Job.update({Job.f_cancel_signal: set_or_reset, Job.f_cancel_time: current_timestamp()}).where(Job.f_job_id == job_id).execute() > 0
    return update_status


@DB.connection_context()
def rerun_signal(job_id, set_or_reset: bool):
    if set_or_reset is True:
        # a rerun implicitly clears any pending cancel and the scheduling-end counter
        update_fields = {Job.f_rerun_signal: True, Job.f_cancel_signal: False, Job.f_end_scheduling_updates: 0}
    elif set_or_reset is False:
        update_fields = {Job.f_rerun_signal: False}
    else:
        raise RuntimeError(f"can not support rerun signal {set_or_reset}")
    update_status = Job.update(update_fields).where(Job.f_job_id == job_id).execute() > 0
    return update_status


def schedule_lock(func):
    @wraps(func)
    def _wrapper(*args, **kwargs):
        _lock = kwargs.pop("lock", False)
        if _lock:
            job = kwargs.get("job")
            schedule_logger(job.f_job_id).info(f"get job {job.f_job_id} schedule lock")
            _result = None
            if not ready_signal(job_id=job.f_job_id, set_or_reset=True):
                schedule_logger(job.f_job_id).info(f"get job {job.f_job_id} schedule lock failed, job may be handled by another scheduler")
                return
            try:
                _result = func(*args, **kwargs)
            finally:
                # always release the lock, even if the wrapped call raises
                ready_signal(job_id=job.f_job_id, set_or_reset=False)
                schedule_logger(job.f_job_id).info(f"release job {job.f_job_id} schedule lock")
            return _result
        else:
            return func(*args, **kwargs)
    return _wrapper
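

# Usage sketch (hedged): a scheduler entry point opts in to locking per call by
# passing lock=True; the wrapper pops that keyword and reads the job from the
# "job" keyword argument, so the job must be passed by name. The function name
# below is hypothetical.
#
#   @schedule_lock
#   def schedule_running_job(job=None):
#       ...
#
#   schedule_running_job(job=job, lock=True)  # guarded by the job's ready signal
#   schedule_running_job(job=job)             # runs without taking the lock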


@DB.connection_context()
def get_job_dsl_parser_by_job_id(job_id):
    jobs = Job.select(Job.f_dsl, Job.f_runtime_conf_on_party, Job.f_train_runtime_conf).where(Job.f_job_id == job_id)
    if jobs:
        job = jobs[0]
        job_dsl_parser = get_job_dsl_parser(dsl=job.f_dsl, runtime_conf=job.f_runtime_conf_on_party,
                                            train_runtime_conf=job.f_train_runtime_conf)
        return job_dsl_parser, job.f_runtime_conf_on_party, job.f_dsl
    else:
        return None, None, None


def get_conf_version(conf: dict):
    return int(conf.get("dsl_version", "1"))


def get_job_dsl_parser(dsl=None, runtime_conf=None, pipeline_dsl=None, train_runtime_conf=None):
    parser_version = get_conf_version(runtime_conf)
    if parser_version == 1:
        # upgrade v1 jobs in place so a single v2 parser handles both versions
        dsl, runtime_conf = convert_dsl_and_conf_v1_to_v2(dsl, runtime_conf)
        if pipeline_dsl and train_runtime_conf:
            pipeline_dsl, train_runtime_conf = convert_dsl_and_conf_v1_to_v2(pipeline_dsl, train_runtime_conf)
        parser_version = 2
    dsl_parser = get_dsl_parser_by_version(parser_version)
    job_type = JobRuntimeConfigAdapter(runtime_conf).get_job_type()
    dsl_parser.run(dsl=dsl,
                   runtime_conf=runtime_conf,
                   pipeline_dsl=pipeline_dsl,
                   pipeline_runtime_conf=train_runtime_conf,
                   mode=job_type)
    return dsl_parser
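

# Usage sketch (hedged): for a dsl_version 2 job the dsl and conf are used as-is;
# a version 1 job is converted to v2 first, so callers get back a parser that has
# already run over the job definition. The conf fragment below is a hypothetical
# minimal stand-in, not a complete job configuration.
#
#   runtime_conf = {"dsl_version": "2", ...}
#   parser = get_job_dsl_parser(dsl=job_dsl, runtime_conf=runtime_conf)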


def federated_order_reset(dest_parties, scheduler_partys_info):
    from copy import deepcopy
    dest_partys_new = []
    scheduler = []
    dest_party_ids_dict = {}
    for dest_role, dest_party_ids in dest_parties:
        new_dest_party_ids = deepcopy(dest_party_ids)
        dest_party_ids_dict[dest_role] = new_dest_party_ids
        for scheduler_role, scheduler_party_id in scheduler_partys_info:
            if dest_role == scheduler_role and scheduler_party_id in dest_party_ids:
                # pull the scheduler's own party out of the normal destinations
                dest_party_ids_dict[dest_role].remove(scheduler_party_id)
                scheduler.append((scheduler_role, [scheduler_party_id]))
        if dest_party_ids_dict[dest_role]:
            dest_partys_new.append((dest_role, dest_party_ids_dict[dest_role]))
    if scheduler:
        # scheduler parties are appended last so they are notified after the others
        dest_partys_new.extend(scheduler)
    return dest_partys_new
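

# Worked example (illustrative party ids): with the scheduler at ("guest", 9999),
# the scheduler's own party is split out and appended last, so it is contacted
# after all other destination parties:
#
#   federated_order_reset(
#       dest_parties=[("guest", [9999]), ("host", [10000, 10001])],
#       scheduler_partys_info=[("guest", 9999)],
#   )
#   # -> [("host", [10000, 10001]), ("guest", [9999])]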


def get_parser_version_mapping():
    return {
        "1": DSLParserV1(),
        "2": DSLParserV2()
    }


def get_dsl_parser_by_version(version: typing.Union[str, int] = 2):
    mapping = get_parser_version_mapping()
    if isinstance(version, int):
        version = str(version)
    if version not in mapping:
        raise Exception("{} version of dsl parser is not currently supported.".format(version))
    return mapping[version]
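

# Example (hedged): the version key is normalized to a string, so int and str
# arguments are interchangeable.
#
#   isinstance(get_dsl_parser_by_version(2), DSLParserV2)    # True
#   isinstance(get_dsl_parser_by_version("1"), DSLParserV1)  # True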


def fill_inference_dsl(dsl_parser: DSLParserV2, origin_inference_dsl, components_parameters: dict = None):
    # the dsl must be filled for fate serving
    if isinstance(dsl_parser, DSLParserV2):
        components_module_name = {}
        for component, param in components_parameters.items():
            components_module_name[component] = param["CodePath"]
        return dsl_parser.get_predict_dsl(predict_dsl=origin_inference_dsl, module_object_dict=components_module_name)
    else:
        raise Exception(f"unsupported dsl parser {type(dsl_parser)}")


def convert_dsl_and_conf_v1_to_v2(dsl, runtime_conf):
    from fate_flow.db.component_registry import ComponentRegistry

    dsl_parser_v1 = DSLParserV1()
    dsl = dsl_parser_v1.convert_dsl_v1_to_v2(dsl)
    components = dsl_parser_v1.get_components_light_weight(dsl)
    job_providers = dsl_parser_v1.get_job_providers(dsl=dsl, provider_detail=ComponentRegistry.REGISTRY)
    cpn_role_parameters = dict()
    for cpn in components:
        cpn_name = cpn.get_name()
        role_params = dsl_parser_v1.parse_component_role_parameters(
            component=cpn_name, dsl=dsl, runtime_conf=runtime_conf,
            provider_detail=ComponentRegistry.REGISTRY,
            provider_name=job_providers[cpn_name]["provider"]["name"],
            provider_version=job_providers[cpn_name]["provider"]["version"])
        cpn_role_parameters[cpn_name] = role_params
    runtime_conf = dsl_parser_v1.convert_conf_v1_to_v2(runtime_conf, cpn_role_parameters)
    return dsl, runtime_conf