123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- #
- # Copyright 2022 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the 'License');
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an 'AS IS' BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import hashlib
- import os
- from pathlib import Path
- from zipfile import ZipFile
- from ruamel import yaml
- from fate_arch.common.base_utils import json_dumps, json_loads
- from fate_flow.db.db_models import DB, PipelineComponentMeta
- from fate_flow.db.db_utils import bulk_insert_into_db
- from fate_flow.model import Locker, local_cache_required, lock
- from fate_flow.pipelined_model import Pipelined
- from fate_flow.settings import TEMP_DIRECTORY
- from fate_flow.utils.base_utils import get_fate_flow_directory
class PipelinedComponent(Pipelined, Locker):
    """Read and persist the per-component artifacts of one pipelined model version.

    Component metadata lives in two places: ``PipelineComponentMeta`` rows in the
    database, and a local cache directory
    ``<model_local_cache>/<party_model_id>/<model_version>/`` containing:

    * ``define/define_meta.yaml`` - component definitions and model proto indexes
    * ``variables/data/<component>/<model_alias>/<model_name>`` - model buffers
    * ``run_parameters/<component>/run_parameters.json`` - run parameters
    * ``checkpoint/<component>/...`` - checkpoints

    ``Locker`` is initialised on the model directory; methods decorated with
    ``@lock`` presumably hold that lock while running (semantics defined in
    ``fate_flow.model`` - confirm there).
    """

    # read size used when hashing model archives, so large archives are not loaded whole
    _HASH_CHUNK_SIZE = 1 << 20

    def __init__(self, **kwargs):
        Pipelined.__init__(self, **kwargs)

        self.model_path = Path(get_fate_flow_directory('model_local_cache'), self.party_model_id, self.model_version)
        self.define_meta_path = self.model_path / 'define' / 'define_meta.yaml'
        self.variables_data_path = self.model_path / 'variables' / 'data'
        self.run_parameters_path = self.model_path / 'run_parameters'
        self.checkpoint_path = self.model_path / 'checkpoint'

        # peewee expressions selecting every PipelineComponentMeta row of this model version
        self.query_args = (
            PipelineComponentMeta.f_model_id == self.model_id,
            PipelineComponentMeta.f_model_version == self.model_version,
            PipelineComponentMeta.f_role == self.role,
            PipelineComponentMeta.f_party_id == self.party_id,
        )

        Locker.__init__(self, self.model_path)

    def exists(self, component_name=None, model_alias=None):
        """Check whether model data is present in the local cache.

        :param component_name: component to check. When None, return whether the
            model directory exists and contains anything besides ``.lock``.
        :param model_alias: alias to check. When None, the alias stored in the
            database row is used; on the file fallback the component must have
            exactly one alias, otherwise False is returned.
        :return: True only if the component's model proto index is non-empty and
            every model file listed in it exists under
            ``variables/data/<component_name>/<model_alias>``.
        """
        if component_name is None:
            # bool() so this predicate always returns a boolean rather than a set
            return self.model_path.is_dir() and bool(set(os.listdir(self.model_path)) - {'.lock'})

        # the database is authoritative; define_meta.yaml is the fallback
        query = self.get_define_meta_from_db(PipelineComponentMeta.f_component_name == component_name)

        if query:
            query = query[0]

            if model_alias is None:
                model_alias = query.f_model_alias

            model_proto_index = query.f_model_proto_index
        else:
            query = self.get_define_meta_from_file()

            try:
                query = query['model_proto'][component_name]
            except KeyError:
                return False

            if model_alias is None:
                # cannot pick an alias unambiguously
                if len(query) != 1:
                    return False
                model_alias = next(iter(query.keys()))

            try:
                model_proto_index = query[model_alias]
            except KeyError:
                return False

        if not model_proto_index:
            return False

        variables_data_path = self.variables_data_path / component_name / model_alias
        for model_name in model_proto_index:
            if not (variables_data_path / model_name).is_file():
                return False

        return True

    def get_define_meta_from_file(self):
        """Load ``define_meta.yaml`` from the local cache, or return ``{}`` if it is missing."""
        if not self.define_meta_path.is_file():
            return {}
        return yaml.safe_load(self.define_meta_path.read_text('utf-8'))

    @DB.connection_context()
    def get_define_meta_from_db(self, *query_args):
        """Return this model version's ``PipelineComponentMeta`` rows as a tuple.

        :param query_args: extra peewee expressions ANDed with ``self.query_args``.
        """
        return tuple(PipelineComponentMeta.select().where(*self.query_args, *query_args))

    def rearrange_define_meta(self, data):
        """Convert ``PipelineComponentMeta`` rows into the ``define_meta.yaml`` layout.

        :param data: iterable of rows.
        :return: ``{'component_define': {name: {'module_name': ...}},
            'model_proto': {name: {model_alias: model_proto_index}}}``
        """
        define_meta = {
            'component_define': {},
            'model_proto': {},
        }

        for row in data:
            define_meta['component_define'][row.f_component_name] = {'module_name': row.f_component_module_name}

            # there is only one model_alias in a component
            if row.f_component_name not in define_meta['model_proto']:
                define_meta['model_proto'][row.f_component_name] = {}
            define_meta['model_proto'][row.f_component_name][row.f_model_alias] = row.f_model_proto_index

        return define_meta

    def get_define_meta(self):
        """Return the define meta, preferring database rows over ``define_meta.yaml``."""
        query = self.get_define_meta_from_db()
        return self.rearrange_define_meta(query) if query else self.get_define_meta_from_file()

    @DB.connection_context()
    def save_define_meta(self, component_name, component_module_name, model_alias, model_proto_index, run_parameters):
        """Upsert one component's metadata row for this model version.

        On conflict, the listed columns are overwritten with the newly inserted values
        (peewee ``on_conflict(preserve=...)``).
        """
        PipelineComponentMeta.insert(
            f_model_id=self.model_id,
            f_model_version=self.model_version,
            f_role=self.role,
            f_party_id=self.party_id,
            f_component_name=component_name,
            f_component_module_name=component_module_name,
            f_model_alias=model_alias,
            f_model_proto_index=model_proto_index,
            f_run_parameters=run_parameters,
        ).on_conflict(preserve=(
            PipelineComponentMeta.f_update_time,
            PipelineComponentMeta.f_update_date,
            PipelineComponentMeta.f_component_module_name,
            PipelineComponentMeta.f_model_alias,
            PipelineComponentMeta.f_model_proto_index,
            PipelineComponentMeta.f_run_parameters,
        )).execute()

    @lock
    def save_define_meta_from_db_to_file(self):
        """Write database metadata into the local cache.

        Writes each component's ``run_parameters.json`` and then ``define_meta.yaml``.
        """
        query = self.get_define_meta_from_db()

        for row in query:
            run_parameters_path = self.get_run_parameters_path(row.f_component_name)
            run_parameters_path.parent.mkdir(parents=True, exist_ok=True)
            with run_parameters_path.open('w', encoding='utf-8') as f:
                f.write(json_dumps(row.f_run_parameters))

        self.define_meta_path.parent.mkdir(parents=True, exist_ok=True)
        with self.define_meta_path.open('w', encoding='utf-8') as f:
            yaml.dump(self.rearrange_define_meta(query), f, Dumper=yaml.RoundTripDumper)

    # import model
    @local_cache_required(True)
    def save_define_meta_from_file_to_db(self, replace_on_conflict=False):
        """Insert the local cache's define meta and run parameters into the database.

        :param replace_on_conflict: forwarded to ``bulk_insert_into_db``. When False,
            refuse to run if this model version already has rows.
        :raises ValueError: rows already exist and ``replace_on_conflict`` is False.
        """
        if not replace_on_conflict:
            with DB.connection_context():
                count = PipelineComponentMeta.select().where(*self.query_args).count()
            if count > 0:
                raise ValueError('The define_meta data already exists in database.')

        define_meta = self.get_define_meta_from_file()
        run_parameters = self.get_run_parameters_from_files()

        insert = []
        for component_name, component_define in define_meta['component_define'].items():
            for model_alias, model_proto_index in define_meta['model_proto'][component_name].items():
                row = {
                    'f_model_id': self.model_id,
                    'f_model_version': self.model_version,
                    'f_role': self.role,
                    'f_party_id': self.party_id,
                    'f_component_name': component_name,
                    'f_component_module_name': component_define['module_name'],
                    'f_model_alias': model_alias,
                    'f_model_proto_index': model_proto_index,
                    'f_run_parameters': run_parameters.get(component_name, {}),
                }
                insert.append(row)

        bulk_insert_into_db(PipelineComponentMeta, insert, replace_on_conflict)

    def replicate_define_meta(self, modification, query_args=(), replace_on_conflict=False):
        """Copy this model version's rows into new rows with some fields overridden.

        :param modification: dict of column values applied to every copied row.
        :param query_args: extra peewee expressions restricting which rows are copied.
        :param replace_on_conflict: forwarded to ``bulk_insert_into_db``.
        """
        query = self.get_define_meta_from_db(*query_args)
        if not query:
            return

        insert = []
        for row in query:
            # copy first so deleting 'id' cannot touch data that to_dict() may share with the row
            row = dict(row.to_dict())
            # drop the primary key so the database assigns a new one
            del row['id']
            row.update(modification)
            insert.append(row)

        bulk_insert_into_db(PipelineComponentMeta, insert, replace_on_conflict)

    def get_run_parameters_path(self, component_name):
        """Return the path of ``run_parameters.json`` for ``component_name``."""
        return self.run_parameters_path / component_name / 'run_parameters.json'

    @lock
    def get_run_parameters_from_files(self):
        """Load every component's run parameters from the local cache.

        Entries under ``run_parameters/`` without a ``run_parameters.json`` are skipped.
        These include stray files and partially written component directories, which
        previously raised ``NotADirectoryError``/``FileNotFoundError``.

        :return: ``{component_name: run_parameters}``; ``{}`` if the directory is missing.
        """
        if not self.run_parameters_path.is_dir():
            return {}

        run_parameters = {}
        for path in self.run_parameters_path.iterdir():
            run_parameters_file = self.get_run_parameters_path(path.name)
            if not run_parameters_file.is_file():
                continue
            run_parameters[path.name] = json_loads(run_parameters_file.read_text('utf-8'))
        return run_parameters

    def get_run_parameters(self):
        """Return ``{component_name: run_parameters}``, preferring the database over local files."""
        query = self.get_define_meta_from_db()
        return {
            row.f_component_name: row.f_run_parameters
            for row in query
        } if query else self.get_run_parameters_from_files()

    def get_archive_path(self, component_name):
        """Return the temp-directory path of the zip archive for ``component_name``."""
        return Path(TEMP_DIRECTORY, f'{self.party_model_id}_{self.model_version}_{component_name}.zip')

    def walk_component(self, zip_file, path: Path):
        """Recursively add every file under ``path`` to ``zip_file``.

        Archive names are relative to the model directory, so ``unpack_component``
        can extract them back into the same location. A missing ``path`` adds nothing.
        """
        if path.is_dir():
            for subpath in path.iterdir():
                self.walk_component(zip_file, subpath)
        elif path.is_file():
            zip_file.write(path, path.relative_to(self.model_path))

    @staticmethod
    def _sha256_of_file(path, chunk_size=_HASH_CHUNK_SIZE):
        """Return the hex SHA-256 digest of ``path``, reading in chunks to bound memory use."""
        digest = hashlib.sha256()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()

    @local_cache_required(True)
    def pack_component(self, component_name):
        """Zip a component's variables and checkpoints into the temp directory.

        :return: ``(archive_path, sha256_hexdigest)``
        """
        filename = self.get_archive_path(component_name)
        # the temp directory may not exist yet; ZipFile would fail to create the archive
        filename.parent.mkdir(parents=True, exist_ok=True)

        with ZipFile(filename, 'w') as zip_file:
            self.walk_component(zip_file, self.variables_data_path / component_name)
            self.walk_component(zip_file, self.checkpoint_path / component_name)

        hash_ = self._sha256_of_file(filename)

        return filename, hash_

    @lock
    def unpack_component(self, component_name, hash_=None):
        """Extract a component archive from the temp directory into the model directory.

        :param hash_: expected SHA-256 hex digest; when given, it is verified first.
        :raises ValueError: the archive's digest does not match ``hash_``.
        """
        filename = self.get_archive_path(component_name)

        if hash_ is not None:
            sha256 = self._sha256_of_file(filename)
            if hash_ != sha256:
                raise ValueError(f'Model archive hash mismatch. path: {filename} expected: {hash_} actual: {sha256}')

        # NOTE: zipfile.extractall strips absolute prefixes and '..' components from member
        # names; the optional hash check is the only integrity check on the archive contents.
        with ZipFile(filename, 'r') as zip_file:
            zip_file.extractall(self.model_path)
|