123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227 |
- #
- # Copyright 2022 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from copy import deepcopy
- from hashlib import sha256
- from typing import Tuple
- from peewee import DoesNotExist
- from fate_flow.db.db_models import (
- DB, PipelineComponentMeta,
- MachineLearningModelInfo as MLModel,
- )
- from fate_flow.db.service_registry import ServerRegistry
- from fate_flow.model import (
- lock, model_storage_base,
- mysql_model_storage, tencent_cos_model_storage,
- )
- from fate_flow.pipelined_model import Pipelined
- from fate_flow.pipelined_model.pipelined_model import PipelinedModel
- from fate_flow.settings import HOST
# Registry of whole-model storage backends, keyed by the ``storage`` value
# that ``get_storage`` pops from the configured model store address.
model_storage_map = {
    'mysql': mysql_model_storage.MysqlModelStorage,
    'tencent_cos': tencent_cos_model_storage.TencentCOSModelStorage,
}

# Registry of per-component storage backends; same keys as above so a single
# configured backend name can be looked up in either map.
component_storage_map = {
    'mysql': mysql_model_storage.MysqlComponentStorage,
    'tencent_cos': tencent_cos_model_storage.TencentCOSComponentStorage,
}
def get_storage(storage_map: dict) -> Tuple[model_storage_base.ModelStorageBase, dict]:
    """Look up the configured storage backend in ``storage_map``.

    The backend name is read from the ``'storage'`` entry of
    ``ServerRegistry.MODEL_STORE_ADDRESS``.

    Args:
        storage_map: Mapping from backend name to the object to return for
            that backend (the module-level maps hold storage classes).

    Returns:
        A 2-tuple ``(storage_map[name], address)`` where ``address`` is a deep
        copy of the configured store address with the ``'storage'`` key
        removed.

    Raises:
        KeyError: If the configured backend name is not a key of
            ``storage_map`` (or the address has no ``'storage'`` entry).
    """
    # Deep-copy first so that popping the backend name never alters the
    # registry's own value.
    address = deepcopy(ServerRegistry.MODEL_STORE_ADDRESS)
    backend_name = address.pop('storage')

    if backend_name in storage_map:
        return storage_map[backend_name], address

    raise KeyError(f"Model storage '{backend_name}' is not supported.")
class SyncModel(Pipelined):
    """Synchronise one pipelined model between local files and model storage.

    The storage backend is selected through ``get_storage(model_storage_map)``.
    ``upload`` and ``download`` run while holding a ``DB.lock`` whose name is
    derived from the party model id and the model version.
    """

    def __init__(self, **kwargs):
        """Set up the local model handle, the storage backend and the lock.

        Args:
            **kwargs: Forwarded unchanged to ``Pipelined.__init__``.
        """
        super().__init__(**kwargs)

        self.pipelined_model = PipelinedModel(self.party_model_id, self.model_version)

        storage_cls, address = get_storage(model_storage_map)
        self.model_storage = storage_cls()
        # Keyword arguments shared by every call into the storage backend.
        self.model_storage_parameters = {
            'model_id': self.party_model_id,
            'model_version': self.model_version,
            'store_address': address,
        }

        # The lock name is the SHA-256 hex digest of the joined identifiers.
        lock_source = '_'.join(('sync_model', self.party_model_id, self.model_version))
        lock_name = sha256(lock_source.encode('utf-8')).hexdigest()
        # NOTE(review): the meaning of -1 is defined by DB.lock — confirm there.
        self.lock = DB.lock(lock_name, -1)

    @DB.connection_context()
    def db_exists(self):
        """Return True when ``get_model`` finds a row, False on DoesNotExist."""
        try:
            self.get_model()
        except DoesNotExist:
            return False
        return True

    def local_exists(self):
        """Return the result of ``self.pipelined_model.exists()``."""
        return self.pipelined_model.exists()

    def remote_exists(self):
        """Ask the storage backend whether this model/version is stored."""
        return self.model_storage.exists(**self.model_storage_parameters)

    def get_model(self):
        """Fetch the ``MLModel`` row matching role, party, model id and version.

        Propagates whatever ``MLModel.get`` raises; ``db_exists`` relies on it
        raising ``DoesNotExist`` when there is no matching row.
        """
        conditions = (
            MLModel.f_role == self.role,
            MLModel.f_party_id == self.party_id,
            MLModel.f_model_id == self.model_id,
            MLModel.f_model_version == self.model_version,
        )
        return MLModel.get(*conditions)

    @DB.connection_context()
    def upload(self, force_update=False):
        """Store the model remotely and record the archive hash in the DB.

        Args:
            force_update: When true, store even if the remote copy exists;
                the flag is also passed on to the storage backend.

        Returns:
            The updated ``MLModel`` row, or ``None`` when the remote copy
            already exists and ``force_update`` is false.
        """
        already_remote = self.remote_exists()
        if already_remote and not force_update:
            return None

        with self.lock:
            record = self.get_model()

            record.f_archive_sha256 = self.model_storage.store(
                force_update=force_update,
                **self.model_storage_parameters,
            )
            record.f_archive_from_ip = HOST
            record.save()

        return record

    @DB.connection_context()
    def download(self, force_update=False):
        """Restore the model from storage using the hash recorded in the DB.

        Args:
            force_update: When true, restore even if a local copy exists;
                the flag is also passed on to the storage backend.

        Returns:
            The ``MLModel`` row that was read, or ``None`` when a local copy
            already exists and ``force_update`` is false.
        """
        already_local = self.local_exists()
        if already_local and not force_update:
            return None

        with self.lock:
            record = self.get_model()
            expected_hash = record.f_archive_sha256

            self.model_storage.restore(
                force_update=force_update,
                hash_=expected_hash,
                **self.model_storage_parameters,
            )

        return record
class SyncComponent(Pipelined):
    """Synchronise a single component of a pipelined model with storage.

    The storage backend is selected through
    ``get_storage(component_storage_map)`` and is used as a context manager
    around every storage call.
    """

    def __init__(self, *, component_name, **kwargs):
        """Set up the storage backend, DB filter and lock for one component.

        Args:
            component_name: Name of the component to synchronise
                (keyword-only).
            **kwargs: Forwarded unchanged to ``Pipelined.__init__``.
        """
        super().__init__(**kwargs)

        self.component_name = component_name
        self.pipelined_model = PipelinedModel(self.party_model_id, self.model_version)

        storage_cls, address = get_storage(component_storage_map)
        # The remaining address entries are passed to the backend constructor.
        self.component_storage = storage_cls(**address)
        # Positional arguments shared by every call into the storage backend.
        self.component_storage_parameters = (
            self.party_model_id,
            self.model_version,
            self.component_name,
        )

        # Filter selecting this component's rows in PipelineComponentMeta.
        self.query_args = (
            PipelineComponentMeta.f_role == self.role,
            PipelineComponentMeta.f_party_id == self.party_id,
            PipelineComponentMeta.f_model_id == self.model_id,
            PipelineComponentMeta.f_model_version == self.model_version,
            PipelineComponentMeta.f_component_name == self.component_name,
        )

        # The lock name is the SHA-256 hex digest of the joined identifiers.
        lock_source = '_'.join((
            'sync_component',
            self.party_model_id,
            self.model_version,
            self.component_name,
        ))
        lock_name = sha256(lock_source.encode('utf-8')).hexdigest()
        # NOTE(review): the meaning of -1 is defined by DB.lock — confirm there.
        self.lock = DB.lock(lock_name, -1)

    @DB.connection_context()
    def db_exists(self):
        """Return True when at least one metadata row matches ``query_args``."""
        matching_rows = PipelineComponentMeta.select().where(*self.query_args).count()
        return matching_rows > 0

    def local_exists(self):
        """Ask the local pipelined component store whether the component exists."""
        return self.pipelined_model.pipelined_component.exists(self.component_name)

    def remote_exists(self):
        """Ask the storage backend whether this component is stored."""
        with self.component_storage as storage:
            return storage.exists(*self.component_storage_parameters)

    def get_archive_hash(self):
        """Return the archive SHA-256 recorded for this component.

        The matching rows are grouped by archive hash and source IP, and
        exactly one group is required.

        Raises:
            ValueError: If the grouped query yields zero or several rows.
        """
        grouped = PipelineComponentMeta.select().where(*self.query_args).group_by(
            PipelineComponentMeta.f_archive_sha256,
            PipelineComponentMeta.f_archive_from_ip,
        )
        rows = tuple(grouped)

        if len(rows) == 1:
            return rows[0].f_archive_sha256

        raise ValueError(f'The define_meta data of {self.component_name} in database is invalid.')

    def update_archive_hash(self, hash_):
        """Write ``hash_`` and ``HOST`` to every row matching ``query_args``."""
        statement = PipelineComponentMeta.update(
            f_archive_sha256=hash_,
            f_archive_from_ip=HOST,
        )
        statement.where(*self.query_args).execute()

    # NOTE(review): ``lock`` comes from fate_flow.model; presumably it runs the
    # method while holding ``self.lock`` — confirm against its definition.
    @DB.connection_context()
    @lock
    def upload(self):
        """Upload the component and record the resulting hash in the DB."""
        # Called only for its validation: it raises ValueError when the
        # metadata rows are inconsistent; the returned hash is not needed.
        self.get_archive_hash()

        with self.component_storage as storage:
            uploaded_hash = storage.upload(*self.component_storage_parameters)

        self.update_archive_hash(uploaded_hash)

    @DB.connection_context()
    @lock
    def download(self):
        """Download the component, passing the hash recorded in the DB."""
        expected_hash = self.get_archive_hash()

        with self.component_storage as storage:
            storage.download(*self.component_storage_parameters, expected_hash)

    @DB.connection_context()
    @lock
    def copy(self, source_model_version, hash_):
        """Copy the component in storage, then record ``hash_`` in the DB.

        Args:
            source_model_version: Passed to the backend's ``copy`` after the
                shared storage parameters.
            hash_: Archive hash to store in the metadata rows afterwards.
        """
        with self.component_storage as storage:
            storage.copy(*self.component_storage_parameters, source_model_version)

        self.update_archive_hash(hash_)
|