job_saver.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import operator
  17. import time
  18. import typing
  19. from fate_arch.common.base_utils import current_timestamp
  20. from fate_flow.db.db_models import DB, Job, Task, DataBaseModel
  21. from fate_flow.entity.run_status import JobStatus, TaskStatus, EndStatus
  22. from fate_flow.utils.log_utils import schedule_logger, sql_logger
  23. from fate_flow.utils import schedule_utils
  24. import peewee
  25. class JobSaver(object):
  26. STATUS_FIELDS = ["status", "party_status"]
  27. @classmethod
  28. def create_job(cls, job_info) -> Job:
  29. return cls.create_job_family_entity(Job, job_info)
  30. @classmethod
  31. def create_task(cls, task_info) -> Task:
  32. return cls.create_job_family_entity(Task, task_info)
  33. @classmethod
  34. @DB.connection_context()
  35. def delete_job(cls, job_id):
  36. Job.delete().where(Job.f_job_id == job_id)
  37. @classmethod
  38. def update_job_status(cls, job_info):
  39. schedule_logger(job_info["job_id"]).info("try to update job status to {}".format(job_info.get("status")))
  40. update_status = cls.update_status(Job, job_info)
  41. if update_status:
  42. schedule_logger(job_info["job_id"]).info("update job status successfully")
  43. if EndStatus.contains(job_info.get("status")):
  44. new_job_info = {}
  45. # only update tag
  46. for k in ["job_id", "role", "party_id", "tag"]:
  47. if k in job_info:
  48. new_job_info[k] = job_info[k]
  49. if not new_job_info.get("tag"):
  50. new_job_info["tag"] = "job_end"
  51. cls.update_entity_table(Job, new_job_info)
  52. else:
  53. schedule_logger(job_info["job_id"]).warning("update job status does not take effect")
  54. return update_status
  55. @classmethod
  56. def update_job(cls, job_info):
  57. schedule_logger(job_info["job_id"]).info("try to update job")
  58. if "status" in job_info:
  59. # Avoid unintentional usage that updates the status
  60. del job_info["status"]
  61. schedule_logger(job_info["job_id"]).warning("try to update job, pop job status")
  62. update_status = cls.update_entity_table(Job, job_info)
  63. if update_status:
  64. schedule_logger(job_info.get("job_id")).info(f"job update successfully: {job_info}")
  65. else:
  66. schedule_logger(job_info.get("job_id")).warning(f"job update does not take effect: {job_info}")
  67. return update_status
  68. @classmethod
  69. def update_task_status(cls, task_info):
  70. schedule_logger(task_info["job_id"]).info("try to update task {} {} status".format(task_info["task_id"], task_info["task_version"]))
  71. update_status = cls.update_status(Task, task_info)
  72. if update_status:
  73. schedule_logger(task_info["job_id"]).info("update task {} {} status successfully: {}".format(task_info["task_id"], task_info["task_version"], task_info))
  74. else:
  75. schedule_logger(task_info["job_id"]).warning("update task {} {} status update does not take effect: {}".format(task_info["task_id"], task_info["task_version"], task_info))
  76. return update_status
  77. @classmethod
  78. def update_task(cls, task_info, report=False):
  79. schedule_logger(task_info["job_id"]).info("try to update task {} {}".format(task_info["task_id"], task_info["task_version"]))
  80. update_status = cls.update_entity_table(Task, task_info)
  81. if task_info.get("error_report") and report:
  82. schedule_logger(task_info["job_id"]).error("role {} party id {} task {} error report: {}".format(
  83. task_info["role"], task_info["party_id"], task_info["task_id"], task_info["error_report"]))
  84. if update_status:
  85. schedule_logger(task_info["job_id"]).info("task {} {} update successfully".format(task_info["task_id"], task_info["task_version"]))
  86. else:
  87. schedule_logger(task_info["job_id"]).warning("task {} {} update does not take effect".format(task_info["task_id"], task_info["task_version"]))
  88. return update_status
  89. @classmethod
  90. def reload_task(cls, source_task, target_task):
  91. task_info = {"job_id": target_task.f_job_id, "task_id": target_task.f_task_id, "task_version": target_task.f_task_version,
  92. "role": target_task.f_role, "party_id": target_task.f_party_id}
  93. update_info = {}
  94. update_list = ["cmd", "elapsed", "end_date", "end_time", "engine_conf", "party_status", "run_ip",
  95. "run_pid", "start_date", "start_time", "status", "worker_id"]
  96. for k in update_list:
  97. update_info[k] = getattr(source_task, f"f_{k}")
  98. task_info.update(update_info)
  99. schedule_logger(task_info["job_id"]).info("try to update task {} {}".format(task_info["task_id"], task_info["task_version"]))
  100. schedule_logger(task_info["job_id"]).info("update info: {}".format(update_info))
  101. update_status = cls.update_entity_table(Task, task_info)
  102. if update_status:
  103. cls.update_task_status(task_info)
  104. schedule_logger(task_info["job_id"]).info("task {} {} update successfully".format(task_info["task_id"], task_info["task_version"]))
  105. else:
  106. schedule_logger(task_info["job_id"]).warning("task {} {} update does not take effect".format(task_info["task_id"], task_info["task_version"]))
  107. return update_status
  108. @classmethod
  109. @DB.connection_context()
  110. def create_job_family_entity(cls, entity_model, entity_info):
  111. obj = entity_model()
  112. obj.f_create_time = current_timestamp()
  113. for k, v in entity_info.items():
  114. attr_name = 'f_%s' % k
  115. if hasattr(entity_model, attr_name):
  116. setattr(obj, attr_name, v)
  117. try:
  118. rows = obj.save(force_insert=True)
  119. if rows != 1:
  120. raise Exception("Create {} failed".format(entity_model))
  121. return obj
  122. except peewee.IntegrityError as e:
  123. if e.args[0] == 1062 or (isinstance(e.args[0], str) and "UNIQUE constraint failed" in e.args[0]):
  124. sql_logger(job_id=entity_info.get("job_id", "fate_flow")).warning(e)
  125. else:
  126. raise Exception("Create {} failed:\n{}".format(entity_model, e))
  127. except Exception as e:
  128. raise Exception("Create {} failed:\n{}".format(entity_model, e))
  129. @classmethod
  130. @DB.connection_context()
  131. def update_status(cls, entity_model: DataBaseModel, entity_info: dict):
  132. query_filters = []
  133. primary_keys = entity_model.get_primary_keys_name()
  134. for p_k in primary_keys:
  135. query_filters.append(operator.attrgetter(p_k)(entity_model) == entity_info[p_k[2:]])
  136. objs = entity_model.select().where(*query_filters)
  137. if not objs:
  138. raise Exception(f"can not found the {entity_model.__name__} record to update")
  139. obj = objs[0]
  140. update_filters = query_filters.copy()
  141. update_info = {"job_id": entity_info["job_id"]}
  142. for status_field in cls.STATUS_FIELDS:
  143. if entity_info.get(status_field) and hasattr(entity_model, f"f_{status_field}"):
  144. if status_field in ["status", "party_status"]:
  145. update_info[status_field] = entity_info[status_field]
  146. old_status = getattr(obj, f"f_{status_field}")
  147. new_status = update_info[status_field]
  148. if_pass = False
  149. if isinstance(obj, Task):
  150. if TaskStatus.StateTransitionRule.if_pass(src_status=old_status, dest_status=new_status):
  151. if_pass = True
  152. elif isinstance(obj, Job):
  153. if JobStatus.StateTransitionRule.if_pass(src_status=old_status, dest_status=new_status):
  154. if_pass = True
  155. if EndStatus.contains(new_status) and new_status not in {JobStatus.SUCCESS, JobStatus.CANCELED}:
  156. update_filters.append(Job.f_rerun_signal == False)
  157. if if_pass:
  158. update_filters.append(operator.attrgetter(f"f_{status_field}")(type(obj)) == old_status)
  159. else:
  160. # not allow update status
  161. update_info.pop(status_field)
  162. return cls.execute_update(old_obj=obj, model=entity_model, update_info=update_info, update_filters=update_filters)
  163. @classmethod
  164. @DB.connection_context()
  165. def update_entity_table(cls, entity_model, entity_info):
  166. query_filters = []
  167. primary_keys = entity_model.get_primary_keys_name()
  168. for p_k in primary_keys:
  169. query_filters.append(operator.attrgetter(p_k)(entity_model) == entity_info[p_k.lstrip("f").lstrip("_")])
  170. objs = entity_model.select().where(*query_filters)
  171. if objs:
  172. obj = objs[0]
  173. else:
  174. raise Exception("can not found the {}".format(entity_model.__name__))
  175. update_filters = query_filters[:]
  176. update_info = {}
  177. update_info.update(entity_info)
  178. for _ in cls.STATUS_FIELDS:
  179. # not allow update status fields by this function
  180. update_info.pop(_, None)
  181. if update_info.get("tag") in {"job_end", "submit_failed"} and hasattr(entity_model, "f_tag"):
  182. if obj.f_start_time:
  183. update_info["end_time"] = current_timestamp()
  184. update_info['elapsed'] = update_info['end_time'] - obj.f_start_time
  185. if update_info.get("progress") and hasattr(entity_model, "f_progress") and update_info["progress"] > 0:
  186. update_filters.append(operator.attrgetter("f_progress")(entity_model) <= update_info["progress"])
  187. return cls.execute_update(old_obj=obj, model=entity_model, update_info=update_info, update_filters=update_filters)
  188. @classmethod
  189. def execute_update(cls, old_obj, model, update_info, update_filters):
  190. update_fields = {}
  191. for k, v in update_info.items():
  192. attr_name = 'f_%s' % k
  193. if hasattr(model, attr_name) and attr_name not in model.get_primary_keys_name():
  194. update_fields[operator.attrgetter(attr_name)(model)] = v
  195. if update_fields:
  196. if update_filters:
  197. operate = old_obj.update(update_fields).where(*update_filters)
  198. else:
  199. operate = old_obj.update(update_fields)
  200. sql_logger(job_id=update_info.get("job_id", "fate_flow")).info(operate)
  201. return operate.execute() > 0
  202. else:
  203. return False
  204. @classmethod
  205. @DB.connection_context()
  206. def query_job(cls, reverse=None, order_by=None, **kwargs):
  207. return Job.query(reverse=reverse, order_by=order_by, **kwargs)
  208. @classmethod
  209. @DB.connection_context()
  210. def get_tasks_asc(cls, job_id, role, party_id):
  211. tasks = Task.query(order_by="create_time", reverse=False, job_id=job_id, role=role, party_id=party_id)
  212. tasks_group = cls.get_latest_tasks(tasks=tasks)
  213. return tasks_group
  214. @classmethod
  215. @DB.connection_context()
  216. def query_task(cls, only_latest=True, reverse=None, order_by=None, **kwargs) -> typing.List[Task]:
  217. tasks = Task.query(reverse=reverse, order_by=order_by, **kwargs)
  218. if only_latest:
  219. tasks_group = cls.get_latest_tasks(tasks=tasks)
  220. return list(tasks_group.values())
  221. else:
  222. return tasks
  223. @classmethod
  224. @DB.connection_context()
  225. def check_task(cls, job_id, role, party_id, components: list):
  226. filters = [
  227. Task.f_job_id == job_id,
  228. Task.f_role == role,
  229. Task.f_party_id == party_id,
  230. Task.f_component_name << components
  231. ]
  232. tasks = Task.select().where(*filters)
  233. if tasks and len(tasks) == len(components):
  234. return True
  235. else:
  236. return False
  237. @classmethod
  238. def get_latest_tasks(cls, tasks):
  239. tasks_group = {}
  240. for task in tasks:
  241. task_key = cls.task_key(task_id=task.f_task_id, role=task.f_role, party_id=task.f_party_id)
  242. if task_key not in tasks_group:
  243. tasks_group[task_key] = task
  244. elif task.f_task_version > tasks_group[task_key].f_task_version:
  245. # update new version task
  246. tasks_group[task_key] = task
  247. return tasks_group
  248. @classmethod
  249. def fill_job_inference_dsl(cls, job_id, role, party_id, dsl_parser, origin_inference_dsl):
  250. # must fill dsl for fate serving
  251. components_parameters = {}
  252. tasks = cls.query_task(job_id=job_id, role=role, party_id=party_id, only_latest=True)
  253. for task in tasks:
  254. components_parameters[task.f_component_name] = task.f_component_parameters
  255. return schedule_utils.fill_inference_dsl(dsl_parser, origin_inference_dsl=origin_inference_dsl, components_parameters=components_parameters)
  256. @classmethod
  257. def task_key(cls, task_id, role, party_id):
  258. return f"{task_id}_{role}_{party_id}"
  259. def str_to_time_stamp(time_str):
  260. time_array = time.strptime(time_str, "%Y-%m-%d %H:%M:%S")
  261. time_stamp = int(time.mktime(time_array) * 1000)
  262. return time_stamp