#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import hashlib
import json
import os
import shutil
import typing

from google.protobuf import json_format

from fate_arch.common.base_utils import json_dumps, json_loads

from fate_flow.component_env_utils import provider_utils
from fate_flow.db.runtime_config import RuntimeConfig
from fate_flow.model import (
    Locker, local_cache_required,
    lock, parse_proto_object,
)
from fate_flow.pipelined_model.pipelined_component import PipelinedComponent
from fate_flow.protobuf.python.pipeline_pb2 import Pipeline
from fate_flow.settings import TEMP_DIRECTORY, stat_logger
from fate_flow.utils.base_utils import get_fate_flow_directory
from fate_flow.utils.job_utils import (
    PIPELINE_COMPONENT_NAME, PIPELINE_MODEL_ALIAS,
    PIPELINE_COMPONENT_MODULE_NAME, PIPELINE_MODEL_NAME,
)


class PipelinedModel(Locker):

    def __init__(self, model_id, model_version):
        """
        Support operations on FATE PipelinedModels.

        :param model_id: the model id stored at the local party.
        :param model_version: the model version.
        """
        os.makedirs(TEMP_DIRECTORY, exist_ok=True)

        self.role, self.party_id, self._model_id = model_id.split('#', 2)
        self.party_model_id = self.model_id = model_id
        self.model_version = model_version
        self.pipelined_component = PipelinedComponent(role=self.role, party_id=self.party_id,
                                                      model_id=self._model_id, model_version=self.model_version)
        self.model_path = self.pipelined_component.model_path

        super().__init__(self.model_path)
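
    # A quick construction sketch (the id and version values are illustrative,
    # not a real model): the first argument is the party-scoped id
    # "role#party_id#model_id", which __init__ splits back into its parts.
    #
    #     model = PipelinedModel(
    #         model_id="guest#9999#arbiter-10000#guest-9999#host-10000#model",
    #         model_version="202201011200000000000",
    #     )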

    def save_pipeline_model(self, pipeline_buffer_object, save_define_meta_file=True):
        model_buffers = {
            PIPELINE_MODEL_NAME: (
                type(pipeline_buffer_object).__name__,
                pipeline_buffer_object.SerializeToString(),
                json_format.MessageToDict(pipeline_buffer_object, including_default_value_fields=True),
            ),
        }
        self.save_component_model(PIPELINE_COMPONENT_NAME, PIPELINE_COMPONENT_MODULE_NAME, PIPELINE_MODEL_ALIAS, model_buffers)

        # when save_define_meta_file is False, only the pipeline model file is updated;
        # the define meta file is left untouched
        if save_define_meta_file:
            self.pipelined_component.save_define_meta_from_db_to_file()

    def save_component_model(self, *args, **kwargs):
        component_model = self.create_component_model(*args, **kwargs)
        self.write_component_model(component_model)

    def create_component_model(self, component_name, component_module_name, model_alias,
                               model_buffers: typing.Dict[str, typing.Tuple[str, bytes, dict]],
                               user_specified_run_parameters: dict = None):
        component_model = {"buffer": {}}
        component_model_storage_path = os.path.join(self.pipelined_component.variables_data_path, component_name, model_alias)

        model_proto_index = {}
        for model_name, (proto_index, object_serialized, object_json) in model_buffers.items():
            # store the path relative to the FATE Flow directory
            storage_path = os.path.join(component_model_storage_path, model_name).replace(get_fate_flow_directory(), "")
            component_model["buffer"][storage_path] = (base64.b64encode(object_serialized).decode(), object_json)

            # index of model name and proto buffer class name
            model_proto_index[model_name] = proto_index
            stat_logger.info(f"saved {component_name} {model_alias} {model_name} buffer")

        component_model["component_name"] = component_name
        component_model["component_module_name"] = component_module_name
        component_model["model_alias"] = model_alias
        component_model["model_proto_index"] = model_proto_index
        component_model["run_parameters"] = user_specified_run_parameters

        return component_model
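
    # Shape of the `model_buffers` argument, sketched with hypothetical names
    # (`param_pb` stands in for any protobuf message): each entry maps a model
    # name to (proto class name, serialized bytes, dict form of the message).
    #
    #     model_buffers = {
    #         "SomeModelParam": (
    #             type(param_pb).__name__,
    #             param_pb.SerializeToString(),
    #             json_format.MessageToDict(param_pb, including_default_value_fields=True),
    #         ),
    #     }
    #     model.save_component_model("some_component_0", "SomeModule", "model", model_buffers)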

    @lock
    def write_component_model(self, component_model):
        for storage_path, (object_serialized_encoded, object_json) in component_model.get("buffer").items():
            storage_path = get_fate_flow_directory() + storage_path
            os.makedirs(os.path.dirname(storage_path), exist_ok=True)

            with open(storage_path, "wb") as fw:
                fw.write(base64.b64decode(object_serialized_encoded.encode()))

            with open(f"{storage_path}.json", "w", encoding="utf8") as fw:
                fw.write(json_dumps(object_json))

        self.pipelined_component.save_define_meta(
            component_model["component_name"], component_model["component_module_name"],
            component_model["model_alias"], component_model["model_proto_index"],
            component_model.get("run_parameters") or {},
        )

        stat_logger.info(f'saved {component_model["component_name"]} {component_model["model_alias"]} successfully')

    @local_cache_required(True)
    def _read_component_model(self, component_name, model_alias):
        component_model_storage_path = os.path.join(self.pipelined_component.variables_data_path, component_name, model_alias)
        model_proto_index = self.get_model_proto_index(component_name=component_name, model_alias=model_alias)

        model_buffers = {}
        for model_name, buffer_name in model_proto_index.items():
            storage_path = os.path.join(component_model_storage_path, model_name)

            with open(storage_path, "rb") as f:
                buffer_object_serialized_string = f.read()

            try:
                with open(f"{storage_path}.json", encoding="utf-8") as f:
                    buffer_object_json_format = json_loads(f.read())
            except FileNotFoundError:
                buffer_object_json_format = ""
                # TODO: should be running in worker
                """
                buffer_object_json_format = json_format.MessageToDict(
                    parse_proto_object(buffer_name, buffer_object_serialized_string),
                    including_default_value_fields=True
                )
                with open(f"{storage_path}.json", "x", encoding="utf-8") as f:
                    f.write(json_dumps(buffer_object_json_format))
                """

            model_buffers[model_name] = (
                buffer_name,
                buffer_object_serialized_string,
                buffer_object_json_format,
            )

        return model_buffers

    # TODO: use different functions instead of passing arguments
    def read_component_model(self, component_name, model_alias=None, parse=True, output_json=False):
        if model_alias is None:
            model_alias = self.get_model_alias(component_name)

        if not self.pipelined_component.exists(component_name, model_alias):
            return {}

        _model_buffers = self._read_component_model(component_name, model_alias)

        model_buffers = {}
        for model_name, (
            buffer_name,
            buffer_object_serialized_string,
            buffer_object_json_format,
        ) in _model_buffers.items():
            if output_json:
                model_buffers[model_name] = buffer_object_json_format
            elif parse:
                model_buffers[model_name] = parse_proto_object(buffer_name, buffer_object_serialized_string)
            else:
                model_buffers[model_name] = [
                    buffer_name,
                    base64.b64encode(buffer_object_serialized_string).decode(),
                ]

        return model_buffers
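
    # The three return modes above, sketched (the component name is hypothetical):
    #
    #     model.read_component_model("some_component_0", output_json=True)  # dict per model name
    #     model.read_component_model("some_component_0")                    # parsed protobuf objects
    #     model.read_component_model("some_component_0", parse=False)       # [proto class name, base64 str]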

    # TODO: integration with read_component_model
    @local_cache_required(True)
    def read_pipeline_model(self, parse=True):
        component_model_storage_path = os.path.join(self.pipelined_component.variables_data_path, PIPELINE_COMPONENT_NAME, PIPELINE_MODEL_ALIAS)
        model_proto_index = self.get_model_proto_index(PIPELINE_COMPONENT_NAME, PIPELINE_MODEL_ALIAS)

        model_buffers = {}
        for model_name, buffer_name in model_proto_index.items():
            with open(os.path.join(component_model_storage_path, model_name), "rb") as fr:
                buffer_object_serialized_string = fr.read()

            model_buffers[model_name] = (parse_proto_object(buffer_name, buffer_object_serialized_string, Pipeline) if parse
                                         else [buffer_name, base64.b64encode(buffer_object_serialized_string).decode()])

        return model_buffers[PIPELINE_MODEL_NAME]

    @local_cache_required(True)
    def collect_models(self, in_bytes=False, b64encode=True):
        define_meta = self.pipelined_component.get_define_meta()
        model_buffers = {}

        for component_name in define_meta.get("model_proto", {}).keys():
            for model_alias, model_proto_index in define_meta["model_proto"][component_name].items():
                component_model_storage_path = os.path.join(self.pipelined_component.variables_data_path, component_name, model_alias)

                for model_name, buffer_name in model_proto_index.items():
                    with open(os.path.join(component_model_storage_path, model_name), "rb") as fr:
                        serialized_string = fr.read()

                    if in_bytes:
                        if b64encode:
                            serialized_string = base64.b64encode(serialized_string).decode()
                        model_buffers[f"{component_name}.{model_alias}:{model_name}"] = serialized_string
                    else:
                        model_buffers[model_name] = parse_proto_object(buffer_name, serialized_string)

        return model_buffers
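
    # Key layout of the result above, with illustrative names: when in_bytes is
    # True, keys look like "some_component_0.model:SomeModelParam" and values
    # are base64 strings (or raw bytes when b64encode=False); otherwise keys
    # are plain model names mapped to parsed protobuf objects.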

    @staticmethod
    def get_model_migrate_tool():
        return provider_utils.get_provider_class_object(RuntimeConfig.COMPONENT_PROVIDER, "model_migrate", True)

    @staticmethod
    def get_homo_model_convert_tool():
        return provider_utils.get_provider_class_object(RuntimeConfig.COMPONENT_PROVIDER, "homo_model_convert", True)

    def exists(self):
        return self.pipelined_component.exists()

    @local_cache_required(True)
    def packaging_model(self):
        self.gen_model_import_config()

        # shutil.make_archive appends '.zip', so this writes self.archive_model_file_path
        shutil.make_archive(self.archive_model_base_path, 'zip', self.model_path)

        with open(self.archive_model_file_path, 'rb') as f:
            hash_ = hashlib.sha256(f.read()).hexdigest()

        stat_logger.info(f'Make model {self.model_id} {self.model_version} archive successfully. '
                         f'path: {self.archive_model_file_path} hash: {hash_}')
        return hash_
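
    # Packaging/unpacking round trip, as a sketch (model instances and paths
    # depend on the deployment):
    #
    #     hash_ = model.packaging_model()
    #     peer = PipelinedModel(model_id=model.model_id, model_version=model.model_version)
    #     peer.unpack_model(model.archive_model_file_path, force_update=True, hash_=hash_)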

    @lock
    def unpack_model(self, archive_file_path: str, force_update: bool = False, hash_: str = None):
        if self.exists() and not force_update:
            raise FileExistsError(f'Model {self.model_id} {self.model_version} local cache already exists.')

        if hash_ is not None:
            with open(archive_file_path, 'rb') as f:
                sha256 = hashlib.sha256(f.read()).hexdigest()

            if hash_ != sha256:
                raise ValueError(f'Model archive hash mismatch. '
                                 f'path: {archive_file_path} expected: {hash_} actual: {sha256}')

        shutil.unpack_archive(archive_file_path, self.model_path, 'zip')
        stat_logger.info(f'Unpack model {self.model_id} {self.model_version} archive successfully. path: {self.model_path}')

    def get_component_define(self, component_name=None):
        component_define = self.pipelined_component.get_define_meta()['component_define']

        if component_name is None:
            return component_define
        return component_define.get(component_name, {})

    def get_model_proto_index(self, component_name=None, model_alias=None):
        model_proto = self.pipelined_component.get_define_meta()['model_proto']

        if component_name is None:
            return model_proto
        model_proto = model_proto.get(component_name, {})

        if model_alias is None:
            return model_proto
        return model_proto.get(model_alias, {})

    def get_model_alias(self, component_name):
        model_proto_index = self.get_model_proto_index(component_name)
        if len(model_proto_index) != 1:
            raise KeyError('Failed to detect "model_alias", please specify it manually.')

        return next(iter(model_proto_index.keys()))

    @property
    def archive_model_base_path(self):
        return os.path.join(TEMP_DIRECTORY, f'{self.party_model_id}_{self.model_version}')

    @property
    def archive_model_file_path(self):
        return f'{self.archive_model_base_path}.zip'

    @local_cache_required(True)
    def calculate_model_file_size(self):
        # total size of all files under the model path, in KiB
        size = 0
        for root, dirs, files in os.walk(self.model_path):
            size += sum([os.path.getsize(os.path.join(root, name)) for name in files])
        return round(size / 1024)

    @local_cache_required(True)
    def gen_model_import_config(self):
        config = {
            'role': self.role,
            'party_id': int(self.party_id),
            'model_id': self._model_id,
            'model_version': self.model_version,
            'file': self.archive_model_file_path,
            'force_update': False,
        }
        with (self.model_path / 'import_model.json').open('w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)
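

# A minimal smoke test, assuming a configured FATE Flow deployment; the id and
# version below are placeholders, not a real model:
if __name__ == "__main__":
    demo = PipelinedModel(
        model_id="guest#9999#arbiter-10000#guest-9999#host-10000#model",
        model_version="202201011200000000000",
    )
    print(demo.exists(), demo.archive_model_file_path)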