  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from fate_arch.common import FederatedCommunicationType
  17. from fate_flow.entity import RetCode
  18. from fate_flow.entity.run_status import StatusSet, TaskStatus, EndStatus, AutoRerunStatus, InterruptStatus
  19. from fate_flow.entity.run_status import FederatedSchedulingStatusCode
  20. from fate_flow.entity.run_status import SchedulingStatusCode
  21. from fate_flow.entity import RunParameters
  22. from fate_flow.utils import job_utils
  23. from fate_flow.scheduler.federated_scheduler import FederatedScheduler
  24. from fate_flow.operation.job_saver import JobSaver
  25. from fate_flow.utils.log_utils import schedule_logger
  26. from fate_flow.manager.resource_manager import ResourceManager
  27. from fate_flow.controller.job_controller import JobController
  28. from fate_flow.db.db_models import Job, Task
  29. from fate_flow.entity.types import TaskCleanResourceType
  30. class TaskScheduler(object):
  31. @classmethod
  32. def schedule(cls, job, dsl_parser, canceled=False):
  33. schedule_logger(job.f_job_id).info("scheduling job tasks")
  34. initiator_tasks_group = JobSaver.get_tasks_asc(job_id=job.f_job_id, role=job.f_role, party_id=job.f_party_id)
  35. waiting_tasks = []
  36. auto_rerun_tasks = []
  37. job_interrupt = False
  38. for initiator_task in initiator_tasks_group.values():
  39. if job.f_runtime_conf_on_party["job_parameters"]["federated_status_collect_type"] == FederatedCommunicationType.PULL:
  40. # collect all parties task party status and store it in the database now
  41. cls.collect_task_of_all_party(job=job, initiator_task=initiator_task)
  42. else:
  43. # all parties report task party status and store it in the initiator database when federated_status_collect_type is push
  44. pass
  45. # get all parties party task status and calculate
  46. new_task_status = cls.get_federated_task_status(job_id=initiator_task.f_job_id, task_id=initiator_task.f_task_id, task_version=initiator_task.f_task_version)
  47. task_interrupt = False
  48. task_status_have_update = False
  49. if new_task_status != initiator_task.f_status:
  50. task_status_have_update = True
  51. initiator_task.f_status = new_task_status
  52. FederatedScheduler.sync_task_status(job=job, task=initiator_task)
  53. if InterruptStatus.contains(new_task_status):
  54. task_interrupt = True
  55. job_interrupt = True
  56. if initiator_task.f_status == TaskStatus.WAITING:
  57. waiting_tasks.append(initiator_task)
  58. elif task_status_have_update and EndStatus.contains(initiator_task.f_status) or task_interrupt:
  59. command_body = {"is_asynchronous": True}
  60. schedule_logger(initiator_task.f_job_id).info(f"stop task body: {command_body}, task status: {initiator_task.f_status}")
  61. FederatedScheduler.stop_task(job=job, task=initiator_task, stop_status=initiator_task.f_status, command_body=command_body)
  62. if not canceled and AutoRerunStatus.contains(initiator_task.f_status):
  63. if initiator_task.f_auto_retries > 0:
  64. auto_rerun_tasks.append(initiator_task)
  65. schedule_logger(job.f_job_id).info(f"task {initiator_task.f_task_id} {initiator_task.f_status} will be retried")
  66. else:
  67. schedule_logger(job.f_job_id).info(f"task {initiator_task.f_task_id} {initiator_task.f_status} has no retry count")
  68. scheduling_status_code = SchedulingStatusCode.NO_NEXT
  69. schedule_logger(job.f_job_id).info(f"canceled status {canceled}, job interrupt status {job_interrupt}")
  70. if not canceled and not job_interrupt:
  71. for waiting_task in waiting_tasks:
  72. for component in dsl_parser.get_upstream_dependent_components(component_name=waiting_task.f_component_name):
  73. dependent_task = initiator_tasks_group[
  74. JobSaver.task_key(task_id=job_utils.generate_task_id(job_id=job.f_job_id, component_name=component.get_name()),
  75. role=job.f_role,
  76. party_id=job.f_party_id
  77. )
  78. ]
  79. if dependent_task.f_status != TaskStatus.SUCCESS:
  80. # can not start task
  81. break
  82. else:
  83. # all upstream dependent tasks have been successful, can start this task
  84. scheduling_status_code = SchedulingStatusCode.HAVE_NEXT
  85. status_code = cls.start_task(job=job, task=waiting_task)
  86. if status_code == SchedulingStatusCode.NO_RESOURCE:
  87. # wait for the next round of scheduling
  88. schedule_logger(job.f_job_id).info(f"task {waiting_task.f_task_id} can not apply resource, wait for the next round of scheduling")
  89. break
  90. elif status_code == SchedulingStatusCode.FAILED:
  91. scheduling_status_code = SchedulingStatusCode.FAILED
  92. waiting_task.f_status = StatusSet.FAILED
  93. FederatedScheduler.sync_task_status(job, waiting_task)
  94. break
  95. else:
  96. schedule_logger(job.f_job_id).info("have cancel signal, pass start job tasks")
  97. schedule_logger(job.f_job_id).info("finish scheduling job tasks")
  98. return scheduling_status_code, auto_rerun_tasks, initiator_tasks_group.values()
  99. @classmethod
  100. def start_task(cls, job, task):
  101. schedule_logger(task.f_job_id).info("try to start task {} {} on {} {}".format(task.f_task_id, task.f_task_version, task.f_role, task.f_party_id))
  102. apply_status = ResourceManager.apply_for_task_resource(task_info=task.to_human_model_dict(only_primary_with=["status"]))
  103. if not apply_status:
  104. return SchedulingStatusCode.NO_RESOURCE
  105. task.f_status = TaskStatus.RUNNING
  106. update_status = JobSaver.update_task_status(task_info=task.to_human_model_dict(only_primary_with=["status"]))
  107. if not update_status:
  108. # Another scheduler scheduling the task
  109. schedule_logger(task.f_job_id).info("task {} {} start on another scheduler".format(task.f_task_id, task.f_task_version))
  110. # Rollback
  111. task.f_status = TaskStatus.WAITING
  112. ResourceManager.return_task_resource(task_info=task.to_human_model_dict(only_primary_with=["status"]))
  113. return SchedulingStatusCode.PASS
  114. schedule_logger(task.f_job_id).info("start task {} {} on {} {}".format(task.f_task_id, task.f_task_version, task.f_role, task.f_party_id))
  115. FederatedScheduler.sync_task_status(job=job, task=task)
  116. status_code, response = FederatedScheduler.start_task(job=job, task=task)
  117. if status_code == FederatedSchedulingStatusCode.SUCCESS:
  118. return SchedulingStatusCode.SUCCESS
  119. else:
  120. return SchedulingStatusCode.FAILED
  121. @classmethod
  122. def prepare_rerun_task(cls, job: Job, task: Task, dsl_parser, auto=False, force=False):
  123. job_id = job.f_job_id
  124. can_rerun = False
  125. if force:
  126. can_rerun = True
  127. auto = False
  128. schedule_logger(job_id).info(f"task {task.f_task_id} {task.f_task_version} with {task.f_status} was forced to rerun")
  129. elif task.f_status in {TaskStatus.SUCCESS}:
  130. schedule_logger(job_id).info(f"task {task.f_task_id} {task.f_task_version} is {task.f_status} and not force reruen, pass rerun")
  131. elif auto and task.f_auto_retries < 1:
  132. schedule_logger(job_id).info(f"task {task.f_task_id} has no retry count, pass rerun")
  133. else:
  134. can_rerun = True
  135. if can_rerun:
  136. if task.f_status != TaskStatus.WAITING:
  137. cls.create_new_version_task(job=job,
  138. task=task,
  139. dsl_parser=dsl_parser,
  140. auto=auto)
  141. return can_rerun
  142. @classmethod
  143. def create_new_version_task(cls, job, task, dsl_parser, auto):
  144. # stop old version task
  145. FederatedScheduler.stop_task(job=job, task=task, stop_status=TaskStatus.CANCELED)
  146. FederatedScheduler.clean_task(job=job, task=task, content_type=TaskCleanResourceType.METRICS)
  147. # create new version task
  148. task.f_task_version = task.f_task_version + 1
  149. if auto:
  150. task.f_auto_retries = task.f_auto_retries - 1
  151. task.f_run_pid = None
  152. task.f_run_ip = None
  153. # todo: FederatedScheduler.create_task and JobController.initialize_tasks will create task twice
  154. status_code, response = FederatedScheduler.create_task(job=job, task=task)
  155. if status_code != FederatedSchedulingStatusCode.SUCCESS:
  156. raise Exception(f"create {task.f_task_id} new version failed")
  157. # create the task holder in db to record information of all participants in the initiator for scheduling
  158. for _role in response:
  159. for _party_id in response[_role]:
  160. if _role == job.f_initiator_role and _party_id == job.f_initiator_party_id:
  161. continue
  162. JobController.initialize_tasks(job_id=job.f_job_id,
  163. role=_role,
  164. party_id=_party_id,
  165. run_on_this_party=False,
  166. initiator_role=job.f_initiator_role,
  167. initiator_party_id=job.f_initiator_party_id,
  168. job_parameters=RunParameters(**job.f_runtime_conf_on_party["job_parameters"]),
  169. dsl_parser=dsl_parser,
  170. components=[task.f_component_name],
  171. task_version=task.f_task_version,
  172. auto_retries=task.f_auto_retries,
  173. runtime_conf=job.f_runtime_conf)
  174. schedule_logger(job.f_job_id).info(f"create task {task.f_task_id} new version {task.f_task_version} successfully")
  175. @classmethod
  176. def collect_task_of_all_party(cls, job, initiator_task, set_status=None):
  177. tasks_on_all_party = JobSaver.query_task(task_id=initiator_task.f_task_id, task_version=initiator_task.f_task_version)
  178. tasks_status_on_all = set([task.f_status for task in tasks_on_all_party])
  179. if not len(tasks_status_on_all) > 1 and TaskStatus.RUNNING not in tasks_status_on_all:
  180. return
  181. status, federated_response = FederatedScheduler.collect_task(job=job, task=initiator_task)
  182. if status != FederatedSchedulingStatusCode.SUCCESS:
  183. schedule_logger(job.f_job_id).warning(f"collect task {initiator_task.f_task_id} {initiator_task.f_task_version} on {initiator_task.f_role} {initiator_task.f_party_id} failed")
  184. for _role in federated_response.keys():
  185. for _party_id, party_response in federated_response[_role].items():
  186. if party_response["retcode"] == RetCode.SUCCESS:
  187. JobSaver.update_task_status(task_info=party_response["data"])
  188. JobSaver.update_task(task_info=party_response["data"])
  189. elif party_response["retcode"] == RetCode.FEDERATED_ERROR and set_status:
  190. tmp_task_info = {
  191. "job_id": initiator_task.f_job_id,
  192. "task_id": initiator_task.f_task_id,
  193. "task_version": initiator_task.f_task_version,
  194. "role": _role,
  195. "party_id": _party_id,
  196. "party_status": TaskStatus.RUNNING
  197. }
  198. JobSaver.update_task_status(task_info=tmp_task_info)
  199. tmp_task_info["party_status"] = set_status
  200. JobSaver.update_task_status(task_info=tmp_task_info)
  201. @classmethod
  202. def get_federated_task_status(cls, job_id, task_id, task_version):
  203. tasks_on_all_party = JobSaver.query_task(task_id=task_id, task_version=task_version)
  204. status_flag = 0
  205. # idmapping role status can only be ignored if all non-idmapping roles success
  206. for task in tasks_on_all_party:
  207. if 'idmapping' not in task.f_role and task.f_party_status != TaskStatus.SUCCESS:
  208. status_flag = 1
  209. break
  210. if status_flag:
  211. tasks_party_status = [task.f_party_status for task in tasks_on_all_party]
  212. else:
  213. tasks_party_status = [task.f_party_status for task in tasks_on_all_party if 'idmapping' not in task.f_role]
  214. status = cls.calculate_multi_party_task_status(tasks_party_status)
  215. schedule_logger(job_id=job_id).info("task {} {} status is {}, calculate by task party status list: {}".format(task_id, task_version, status, tasks_party_status))
  216. return status
  217. @classmethod
  218. def calculate_multi_party_task_status(cls, tasks_party_status):
  219. # 1. all waiting
  220. # 2. have interrupt status, should be interrupted
  221. # 3. have running
  222. # 4. waiting + success/pass
  223. # 5. all the same end status
  224. tmp_status_set = set(tasks_party_status)
  225. if TaskStatus.PASS in tmp_status_set:
  226. tmp_status_set.remove(TaskStatus.PASS)
  227. tmp_status_set.add(TaskStatus.SUCCESS)
  228. if len(tmp_status_set) == 1:
  229. # 1 and 5
  230. return tmp_status_set.pop()
  231. else:
  232. # 2
  233. for status in sorted(InterruptStatus.status_list(), key=lambda s: StatusSet.get_level(status=s), reverse=True):
  234. if status in tmp_status_set:
  235. return status
  236. # 3
  237. if TaskStatus.RUNNING in tmp_status_set:
  238. return TaskStatus.RUNNING
  239. # 4
  240. if TaskStatus.SUCCESS in tmp_status_set:
  241. return TaskStatus.RUNNING
  242. raise Exception("Calculate task status failed: {}".format(tasks_party_status))