# checkpoint.py
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import hashlib
  17. from pathlib import Path
  18. from typing import Dict, Tuple
  19. from shutil import copytree, rmtree
  20. from base64 import b64encode
  21. from datetime import datetime
  22. from collections import deque, OrderedDict
  23. from ruamel import yaml
  24. from fate_arch.common.base_utils import json_dumps, json_loads
  25. from fate_flow.settings import stat_logger
  26. from fate_flow.entity import RunParameters
  27. from fate_flow.utils.model_utils import gen_party_model_id
  28. from fate_flow.utils.base_utils import get_fate_flow_directory
  29. from fate_flow.model import Locker
  30. class Checkpoint(Locker):
  31. def __init__(self, directory: Path, step_index: int, step_name: str):
  32. self.step_index = step_index
  33. self.step_name = step_name
  34. self.create_time = None
  35. directory = directory / f'{step_index}#{step_name}'
  36. self.database = directory / 'database.yaml'
  37. super().__init__(directory)
  38. @property
  39. def available(self):
  40. return self.database.exists()
  41. def save(self, model_buffers: Dict[str, Tuple[str, bytes, dict]]):
  42. if not model_buffers:
  43. raise ValueError('model_buffers is empty.')
  44. self.create_time = datetime.utcnow()
  45. data = {
  46. 'step_index': self.step_index,
  47. 'step_name': self.step_name,
  48. 'create_time': self.create_time.isoformat(),
  49. 'models': {},
  50. }
  51. model_data = {}
  52. for model_name, (pb_name, serialized_string, json_format_dict) in model_buffers.items():
  53. model_data[model_name] = (serialized_string, json_format_dict)
  54. data['models'][model_name] = {
  55. 'sha1': hashlib.sha1(serialized_string).hexdigest(),
  56. 'buffer_name': pb_name,
  57. }
  58. self.directory.mkdir(parents=True, exist_ok=True)
  59. with self.lock:
  60. for model_name, model in data['models'].items():
  61. serialized_string, json_format_dict = model_data[model_name]
  62. (self.directory / f'{model_name}.pb').write_bytes(serialized_string)
  63. (self.directory / f'{model_name}.json').write_text(json_dumps(json_format_dict), 'utf8')
  64. self.database.write_text(yaml.dump(data, Dumper=yaml.RoundTripDumper), 'utf8')
  65. stat_logger.info(f'Checkpoint saved. path: {self.directory}')
  66. return self.directory
  67. def read_database(self):
  68. with self.lock:
  69. data = yaml.safe_load(self.database.read_text('utf8'))
  70. if data['step_index'] != self.step_index or data['step_name'] != self.step_name:
  71. raise ValueError('Checkpoint may be incorrect: step_index or step_name dose not match. '
  72. f'filepath: {self.database} '
  73. f'expected step_index: {self.step_index} actual step_index: {data["step_index"]} '
  74. f'expected step_name: {self.step_name} actual step_index: {data["step_name"]}')
  75. self.create_time = datetime.fromisoformat(data['create_time'])
  76. return data
  77. def read(self, parse_models: bool = True, include_database: bool = False):
  78. data = self.read_database()
  79. with self.lock:
  80. for model_name, model in data['models'].items():
  81. model['filepath_pb'] = self.directory / f'{model_name}.pb'
  82. model['filepath_json'] = self.directory / f'{model_name}.json'
  83. if not model['filepath_pb'].exists() or not model['filepath_json'].exists():
  84. raise FileNotFoundError(
  85. 'Checkpoint is incorrect: protobuf file or json file not found. '
  86. f'protobuf filepath: {model["filepath_pb"]} json filepath: {model["filepath_json"]}'
  87. )
  88. model_data = {
  89. model_name: (
  90. model['filepath_pb'].read_bytes(),
  91. json_loads(model['filepath_json'].read_text('utf8')),
  92. )
  93. for model_name, model in data['models'].items()
  94. }
  95. for model_name, model in data['models'].items():
  96. serialized_string, json_format_dict = model_data[model_name]
  97. sha1 = hashlib.sha1(serialized_string).hexdigest()
  98. if sha1 != model['sha1']:
  99. raise ValueError('Checkpoint may be incorrect: hash dose not match. '
  100. f'filepath: {model["filepath"]} expected: {model["sha1"]} actual: {sha1}')
  101. data['models'] = {
  102. model_name: (
  103. model['buffer_name'],
  104. *model_data[model_name],
  105. ) if parse_models
  106. else b64encode(model_data[model_name][0]).decode('ascii')
  107. for model_name, model in data['models'].items()
  108. }
  109. return data if include_database else data['models']
  110. def remove(self):
  111. self.create_time = None
  112. rmtree(self.directory)
  113. def to_dict(self, include_models: bool = False):
  114. if not include_models:
  115. return self.read_database()
  116. return self.read(False, True)
  117. class CheckpointManager:
  118. def __init__(self, job_id: str = None, role: str = None, party_id: int = None,
  119. model_id: str = None, model_version: str = None,
  120. component_name: str = None, component_module_name: str = None,
  121. task_id: str = None, task_version: int = None,
  122. job_parameters: RunParameters = None,
  123. max_to_keep: int = None
  124. ):
  125. self.job_id = job_id
  126. self.role = role
  127. self.party_id = party_id
  128. self.model_id = model_id
  129. self.model_version = model_version
  130. self.party_model_id = gen_party_model_id(self.model_id, self.role, self.party_id)
  131. self.component_name = component_name if component_name else 'pipeline'
  132. self.module_name = component_module_name if component_module_name else 'Pipeline'
  133. self.task_id = task_id
  134. self.task_version = task_version
  135. self.job_parameters = job_parameters
  136. self.directory = (Path(get_fate_flow_directory()) / 'model_local_cache' /
  137. self.party_model_id / model_version / 'checkpoint' / self.component_name)
  138. if isinstance(max_to_keep, int):
  139. if max_to_keep <= 0:
  140. raise ValueError('max_to_keep must be positive')
  141. elif max_to_keep is not None:
  142. raise TypeError('max_to_keep must be an integer')
  143. self.checkpoints = deque(maxlen=max_to_keep)
  144. def load_checkpoints_from_disk(self):
  145. checkpoints = []
  146. for directory in self.directory.glob('*'):
  147. if not directory.is_dir() or '#' not in directory.name:
  148. continue
  149. step_index, step_name = directory.name.split('#', 1)
  150. checkpoint = Checkpoint(self.directory, int(step_index), step_name)
  151. if not checkpoint.available:
  152. continue
  153. checkpoints.append(checkpoint)
  154. self.checkpoints = deque(sorted(checkpoints, key=lambda i: i.step_index), self.max_checkpoints_number)
  155. @property
  156. def checkpoints_number(self):
  157. return len(self.checkpoints)
  158. @property
  159. def max_checkpoints_number(self):
  160. return self.checkpoints.maxlen
  161. @property
  162. def number_indexed_checkpoints(self):
  163. return OrderedDict((i.step_index, i) for i in self.checkpoints)
  164. @property
  165. def name_indexed_checkpoints(self):
  166. return OrderedDict((i.step_name, i) for i in self.checkpoints)
  167. def get_checkpoint_by_index(self, step_index: int):
  168. return self.number_indexed_checkpoints.get(step_index)
  169. def get_checkpoint_by_name(self, step_name: str):
  170. return self.name_indexed_checkpoints.get(step_name)
  171. @property
  172. def latest_checkpoint(self):
  173. if self.checkpoints:
  174. return self.checkpoints[-1]
  175. @property
  176. def latest_step_index(self):
  177. if self.latest_checkpoint is not None:
  178. return self.latest_checkpoint.step_index
  179. @property
  180. def latest_step_name(self):
  181. if self.latest_checkpoint is not None:
  182. return self.latest_checkpoint.step_name
  183. def new_checkpoint(self, step_index: int, step_name: str):
  184. if self.job_parameters is not None and self.job_parameters.job_type == 'predict':
  185. raise ValueError('Cannot create checkpoint in predict job.')
  186. popped_checkpoint = None
  187. if self.max_checkpoints_number and self.checkpoints_number >= self.max_checkpoints_number:
  188. popped_checkpoint = self.checkpoints[0]
  189. checkpoint = Checkpoint(self.directory, step_index, step_name)
  190. self.checkpoints.append(checkpoint)
  191. if popped_checkpoint is not None:
  192. popped_checkpoint.remove()
  193. return checkpoint
  194. def clean(self):
  195. self.checkpoints = deque(maxlen=self.max_checkpoints_number)
  196. rmtree(self.directory)
  197. # copy the checkpoint as a component model to the new model version
  198. def deploy(self, new_model_version: str, model_alias: str, step_index: int = None, step_name: str = None):
  199. if step_index is not None:
  200. checkpoint = self.get_checkpoint_by_index(step_index)
  201. elif step_name is not None:
  202. checkpoint = self.get_checkpoint_by_name(step_name)
  203. else:
  204. raise KeyError('step_index or step_name is required.')
  205. if checkpoint is None:
  206. raise TypeError('Checkpoint not found.')
  207. # check files hash
  208. checkpoint.read()
  209. directory = Path(get_fate_flow_directory()) / 'model_local_cache' / self.party_model_id / new_model_version
  210. target = directory / 'variables' / 'data' / self.component_name / model_alias
  211. locker = Locker(directory)
  212. with locker.lock:
  213. rmtree(target, True)
  214. copytree(checkpoint.directory, target,
  215. ignore=lambda src, names: {i for i in names if i.startswith('.')})
  216. for f in target.glob('*.pb'):
  217. f.replace(f.with_suffix(''))
  218. def to_dict(self, include_models: bool = False):
  219. return [checkpoint.to_dict(include_models) for checkpoint in self.checkpoints]