#
# Copyright 2021 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import hashlib
from pathlib import Path
from typing import Dict, Tuple
from shutil import copytree, rmtree
from base64 import b64encode
from datetime import datetime
from collections import deque, OrderedDict

from ruamel import yaml

from fate_arch.common.base_utils import json_dumps, json_loads
from fate_flow.settings import stat_logger
from fate_flow.entity import RunParameters
from fate_flow.utils.model_utils import gen_party_model_id
from fate_flow.utils.base_utils import get_fate_flow_directory
from fate_flow.model import Locker


class Checkpoint(Locker):
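    """A model checkpoint for a single training step.

    Each checkpoint owns a directory named ``<step_index>#<step_name>``.
    Inside it, every model buffer is stored as a ``<model_name>.pb``
    (serialized protobuf) and a ``<model_name>.json`` (JSON form), and
    ``database.yaml`` indexes the models together with their SHA-1 digests.
    """
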
    def __init__(self, directory: Path, step_index: int, step_name: str):
        self.step_index = step_index
        self.step_name = step_name
        self.create_time = None

        directory = directory / f'{step_index}#{step_name}'
        self.database = directory / 'database.yaml'

        super().__init__(directory)

    @property
    def available(self):
        return self.database.exists()

    def save(self, model_buffers: Dict[str, Tuple[str, bytes, dict]]):
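        """Persist the given model buffers to this checkpoint's directory.

        `model_buffers` maps each model name to a tuple of
        (buffer name, serialized protobuf bytes, dict in JSON format).
        Returns the checkpoint directory.
        """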
        if not model_buffers:
            raise ValueError('model_buffers is empty.')

        self.create_time = datetime.utcnow()
        data = {
            'step_index': self.step_index,
            'step_name': self.step_name,
            'create_time': self.create_time.isoformat(),
            'models': {},
        }

        model_data = {}
        for model_name, (pb_name, serialized_string, json_format_dict) in model_buffers.items():
            model_data[model_name] = (serialized_string, json_format_dict)

            data['models'][model_name] = {
                'sha1': hashlib.sha1(serialized_string).hexdigest(),
                'buffer_name': pb_name,
            }

        self.directory.mkdir(parents=True, exist_ok=True)

        with self.lock:
            for model_name, model in data['models'].items():
                serialized_string, json_format_dict = model_data[model_name]
                (self.directory / f'{model_name}.pb').write_bytes(serialized_string)
                (self.directory / f'{model_name}.json').write_text(json_dumps(json_format_dict), 'utf8')

            self.database.write_text(yaml.dump(data, Dumper=yaml.RoundTripDumper), 'utf8')

        stat_logger.info(f'Checkpoint saved. path: {self.directory}')
        return self.directory

    def read_database(self):
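        """Load and validate ``database.yaml``, returning its contents."""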
        with self.lock:
            data = yaml.safe_load(self.database.read_text('utf8'))
            if data['step_index'] != self.step_index or data['step_name'] != self.step_name:
                raise ValueError('Checkpoint may be incorrect: step_index or step_name does not match. '
                                 f'filepath: {self.database} '
                                 f'expected step_index: {self.step_index} actual step_index: {data["step_index"]} '
                                 f'expected step_name: {self.step_name} actual step_name: {data["step_name"]}')

        self.create_time = datetime.fromisoformat(data['create_time'])
        return data

    def read(self, parse_models: bool = True, include_database: bool = False):
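        """Read the models stored in this checkpoint, verifying their SHA-1 digests.

        If `parse_models` is True, each model maps to
        (buffer name, serialized protobuf bytes, dict in JSON format);
        otherwise the serialized bytes are returned base64-encoded.
        If `include_database` is True, the whole database dict is returned
        instead of just the models.
        """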
        data = self.read_database()

        with self.lock:
            for model_name, model in data['models'].items():
                model['filepath_pb'] = self.directory / f'{model_name}.pb'
                model['filepath_json'] = self.directory / f'{model_name}.json'
                if not model['filepath_pb'].exists() or not model['filepath_json'].exists():
                    raise FileNotFoundError(
                        'Checkpoint is incorrect: protobuf file or json file not found. '
                        f'protobuf filepath: {model["filepath_pb"]} json filepath: {model["filepath_json"]}'
                    )

            model_data = {
                model_name: (
                    model['filepath_pb'].read_bytes(),
                    json_loads(model['filepath_json'].read_text('utf8')),
                )
                for model_name, model in data['models'].items()
            }

        for model_name, model in data['models'].items():
            serialized_string, json_format_dict = model_data[model_name]

            sha1 = hashlib.sha1(serialized_string).hexdigest()
            if sha1 != model['sha1']:
                raise ValueError('Checkpoint may be incorrect: hash does not match. '
                                 f'filepath: {model["filepath_pb"]} expected: {model["sha1"]} actual: {sha1}')

        data['models'] = {
            model_name: (
                model['buffer_name'],
                *model_data[model_name],
            ) if parse_models
            else b64encode(model_data[model_name][0]).decode('ascii')
            for model_name, model in data['models'].items()
        }
        return data if include_database else data['models']

    def remove(self):
        self.create_time = None
        rmtree(self.directory)

    def to_dict(self, include_models: bool = False):
        if not include_models:
            return self.read_database()
        return self.read(False, True)


class CheckpointManager:
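    """Manage the checkpoints of one component, keeping at most
    `max_to_keep` of them on disk (all of them if `max_to_keep` is None).

    Checkpoints live under
    ``model_local_cache/<party_model_id>/<model_version>/checkpoint/<component_name>``.
    """
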
    def __init__(self, job_id: str = None, role: str = None, party_id: int = None,
                 model_id: str = None, model_version: str = None,
                 component_name: str = None, component_module_name: str = None,
                 task_id: str = None, task_version: int = None,
                 job_parameters: RunParameters = None,
                 max_to_keep: int = None
                 ):
        self.job_id = job_id
        self.role = role
        self.party_id = party_id
        self.model_id = model_id
        self.model_version = model_version
        self.party_model_id = gen_party_model_id(self.model_id, self.role, self.party_id)
        self.component_name = component_name if component_name else 'pipeline'
        self.module_name = component_module_name if component_module_name else 'Pipeline'
        self.task_id = task_id
        self.task_version = task_version
        self.job_parameters = job_parameters

        self.directory = (Path(get_fate_flow_directory()) / 'model_local_cache' /
                          self.party_model_id / model_version / 'checkpoint' / self.component_name)

        if isinstance(max_to_keep, int):
            if max_to_keep <= 0:
                raise ValueError('max_to_keep must be positive')
        elif max_to_keep is not None:
            raise TypeError('max_to_keep must be an integer')
        self.checkpoints = deque(maxlen=max_to_keep)

    def load_checkpoints_from_disk(self):
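        """Rebuild the in-memory deque from the checkpoint directories on disk.

        Directories that do not follow the ``<step_index>#<step_name>``
        naming scheme, or that lack a ``database.yaml``, are skipped.
        """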
        checkpoints = []
        for directory in self.directory.glob('*'):
            if not directory.is_dir() or '#' not in directory.name:
                continue

            step_index, step_name = directory.name.split('#', 1)
            checkpoint = Checkpoint(self.directory, int(step_index), step_name)

            if not checkpoint.available:
                continue

            checkpoints.append(checkpoint)

        self.checkpoints = deque(sorted(checkpoints, key=lambda i: i.step_index), self.max_checkpoints_number)

    @property
    def checkpoints_number(self):
        return len(self.checkpoints)

    @property
    def max_checkpoints_number(self):
        return self.checkpoints.maxlen

    @property
    def number_indexed_checkpoints(self):
        return OrderedDict((i.step_index, i) for i in self.checkpoints)

    @property
    def name_indexed_checkpoints(self):
        return OrderedDict((i.step_name, i) for i in self.checkpoints)

    def get_checkpoint_by_index(self, step_index: int):
        return self.number_indexed_checkpoints.get(step_index)

    def get_checkpoint_by_name(self, step_name: str):
        return self.name_indexed_checkpoints.get(step_name)

    @property
    def latest_checkpoint(self):
        if self.checkpoints:
            return self.checkpoints[-1]

    @property
    def latest_step_index(self):
        if self.latest_checkpoint is not None:
            return self.latest_checkpoint.step_index

    @property
    def latest_step_name(self):
        if self.latest_checkpoint is not None:
            return self.latest_checkpoint.step_name

    def new_checkpoint(self, step_index: int, step_name: str):
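        """Create a new checkpoint, evicting the oldest one when the
        deque is full. Not allowed in predict jobs."""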
        if self.job_parameters is not None and self.job_parameters.job_type == 'predict':
            raise ValueError('Cannot create checkpoint in predict job.')

        # the deque evicts its oldest element automatically on append,
        # so keep a reference first and delete its directory afterwards
        popped_checkpoint = None
        if self.max_checkpoints_number and self.checkpoints_number >= self.max_checkpoints_number:
            popped_checkpoint = self.checkpoints[0]

        checkpoint = Checkpoint(self.directory, step_index, step_name)
        self.checkpoints.append(checkpoint)

        if popped_checkpoint is not None:
            popped_checkpoint.remove()

        return checkpoint

    def clean(self):
        self.checkpoints = deque(maxlen=self.max_checkpoints_number)
        rmtree(self.directory)

    # copy the checkpoint as a component model to the new model version
    def deploy(self, new_model_version: str, model_alias: str, step_index: int = None, step_name: str = None):
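        """Deploy a checkpoint, looked up by `step_index` or `step_name`,
        into the given model version under `model_alias`."""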
        if step_index is not None:
            checkpoint = self.get_checkpoint_by_index(step_index)
        elif step_name is not None:
            checkpoint = self.get_checkpoint_by_name(step_name)
        else:
            raise KeyError('step_index or step_name is required.')

        if checkpoint is None:
            raise TypeError('Checkpoint not found.')
        # check files hash
        checkpoint.read()

        directory = Path(get_fate_flow_directory()) / 'model_local_cache' / self.party_model_id / new_model_version
        target = directory / 'variables' / 'data' / self.component_name / model_alias
        locker = Locker(directory)

        with locker.lock:
            rmtree(target, True)
            copytree(checkpoint.directory, target,
                     ignore=lambda src, names: {i for i in names if i.startswith('.')})

            # drop the '.pb' suffix so the files match the component model layout
            for f in target.glob('*.pb'):
                f.replace(f.with_suffix(''))

    def to_dict(self, include_models: bool = False):
        return [checkpoint.to_dict(include_models) for checkpoint in self.checkpoints]
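
# A minimal usage sketch (illustrative only -- the identifiers below are
# hypothetical, not taken from a real job):
#
#     manager = CheckpointManager(
#         role='guest', party_id=9999,
#         model_id='guest-9999#host-10000#model',
#         model_version='202101010000000000000',
#         component_name='hetero_lr_0', max_to_keep=5,
#     )
#     manager.load_checkpoints_from_disk()
#     checkpoint = manager.new_checkpoint(step_index=0, step_name='epoch_0')
#     checkpoint.save({'model': ('HeteroLogisticRegressionMeta', b'<bytes>', {})})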