123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from copy import deepcopy
- from qcloud_cos import CosConfig, CosS3Client
- from qcloud_cos.cos_exception import CosServiceError
- from fate_flow.model.model_storage_base import ComponentStorageBase, ModelStorageBase
- from fate_flow.pipelined_model.pipelined_model import PipelinedModel
- from fate_flow.pipelined_model.pipelined_component import PipelinedComponent
- from fate_flow.utils.log_utils import getLogger
- LOGGER = getLogger()
class TencentCOSModelStorage(ModelStorageBase):
    """
    Model storage backend that keeps packaged pipelined models in Tencent COS

    Every model version is stored as one zip archive whose object key is
    built by :meth:`store_key`. Each operation receives a ``store_address``
    dict that must contain ``Bucket``; apart from ``Bucket`` and the optional
    ``storage`` entry, all of its items are passed to ``CosConfig`` as
    keyword arguments (see :meth:`get_connection`).
    """

    def store_key(self, model_id: str, model_version: str):
        """
        Build the COS object key of a model archive
        :param model_id:
        :param model_version:
        :return: ``FATEFlow/PipelinedModel/<model_id>/<model_version>.zip``
        """
        return f'FATEFlow/PipelinedModel/{model_id}/{model_version}.zip'

    def exists(self, model_id: str, model_version: str, store_address: dict):
        """
        Check whether the model archive is present in the bucket
        :param model_id:
        :param model_version:
        :param store_address:
        :return: True if the HEAD request succeeds, False if COS answers
                 with the ``NoSuchResource`` error code
        :raises CosServiceError: for any other service error
        """
        store_key = self.store_key(model_id, model_version)
        cos = self.get_connection(store_address)

        try:
            cos.head_object(
                Bucket=store_address["Bucket"],
                Key=store_key,
            )
        except CosServiceError as e:
            if e.get_error_code() != 'NoSuchResource':
                # Anything other than "object missing" is a real failure;
                # a bare raise keeps the original traceback untouched.
                raise
            return False
        else:
            return True

    def store(self, model_id: str, model_version: str, store_address: dict, force_update: bool = False):
        """
        Store the model from local cache to cos
        :param model_id:
        :param model_version:
        :param store_address:
        :param force_update: when False, refuse to overwrite an existing object
        :return: the value returned by ``PipelinedModel.packaging_model()``
        :raises FileExistsError: the object already exists and force_update is False
        :raises Exception: packaging or uploading failed; the original error
                           is logged and attached as ``__cause__``
        """
        store_key = self.store_key(model_id, model_version)
        if not force_update and self.exists(model_id, model_version, store_address):
            raise FileExistsError(f"The object {store_key} already exists.")

        model = PipelinedModel(model_id, model_version)
        cos = self.get_connection(store_address)

        try:
            hash_ = model.packaging_model()

            response = cos.upload_file(
                Bucket=store_address["Bucket"],
                LocalFilePath=model.archive_model_file_path,
                Key=store_key,
                EnableMD5=True,
            )
        except Exception as e:
            LOGGER.exception(e)
            # Chain explicitly so callers can still reach the root cause.
            raise Exception(f"Store model {model_id} {model_version} to Tencent COS failed.") from e
        else:
            LOGGER.info(f"Store model {model_id} {model_version} to Tencent COS successfully. "
                        f"Archive path: {model.archive_model_file_path} Key: {store_key} ETag: {response['ETag']}")
            return hash_

    def restore(self, model_id: str, model_version: str, store_address: dict, force_update: bool = False, hash_: str = None):
        """
        Restore model from cos to local cache
        :param model_id:
        :param model_version:
        :param store_address:
        :param force_update: forwarded to ``PipelinedModel.unpack_model``
        :param hash_: forwarded to ``PipelinedModel.unpack_model``
                      (presumably the hash returned by :meth:`store` -- confirm)
        :return:
        :raises Exception: downloading or unpacking failed; the original error
                           is logged and attached as ``__cause__``
        """
        store_key = self.store_key(model_id, model_version)
        model = PipelinedModel(model_id, model_version)
        cos = self.get_connection(store_address)

        try:
            cos.download_file(
                Bucket=store_address["Bucket"],
                Key=store_key,
                DestFilePath=model.archive_model_file_path,
                EnableCRC=True,
            )

            model.unpack_model(model.archive_model_file_path, force_update, hash_)
        except Exception as e:
            LOGGER.exception(e)
            # Chain explicitly so callers can still reach the root cause.
            raise Exception(f"Restore model {model_id} {model_version} from Tencent COS failed.") from e
        else:
            LOGGER.info(f"Restore model {model_id} {model_version} from Tencent COS successfully. "
                        f"Archive path: {model.archive_model_file_path} Key: {store_key}")

    @staticmethod
    def get_connection(store_address: dict):
        """
        Create a COS client from a store address
        :param store_address: must contain ``Bucket``; may contain ``storage``
        :return: a ``CosS3Client`` configured with the remaining items
        :raises KeyError: ``Bucket`` is missing from store_address
        """
        # Work on a copy so the caller's dict keeps its 'Bucket' entry,
        # which the other methods still read after connecting.
        store_address = deepcopy(store_address)
        store_address.pop('storage', None)
        store_address.pop('Bucket')

        return CosS3Client(CosConfig(**store_address))
class TencentCOSComponentStorage(ComponentStorageBase):
    """
    Component storage that keeps packed pipelined components in one Tencent COS bucket

    The client, bucket name and region are fixed at construction time; every
    component archive lives under the key built by :meth:`get_key`.
    """

    def __init__(self, Region, SecretId, SecretKey, Bucket):
        """
        Create the COS client and remember the bucket and region
        :param Region: passed to ``CosConfig`` and reused as the copy source region
        :param SecretId: passed to ``CosConfig``
        :param SecretKey: passed to ``CosConfig``
        :param Bucket: bucket used by every operation of this instance
        """
        cos_config = CosConfig(Region=Region, SecretId=SecretId, SecretKey=SecretKey)
        self.client = CosS3Client(cos_config)
        self.bucket = Bucket
        self.region = Region

    def get_key(self, party_model_id, model_version, component_name):
        """
        Build the COS object key of a component archive
        :param party_model_id:
        :param model_version:
        :param component_name:
        :return: ``FATEFlow/PipelinedComponent/<party_model_id>/<model_version>/<component_name>.zip``
        """
        return 'FATEFlow/PipelinedComponent/{}/{}/{}.zip'.format(
            party_model_id, model_version, component_name,
        )

    def exists(self, party_model_id, model_version, component_name):
        """
        Check whether the component archive is present in the bucket
        :param party_model_id:
        :param model_version:
        :param component_name:
        :return: True if the HEAD request succeeds, False if COS answers
                 with the ``NoSuchResource`` error code
        :raises CosServiceError: for any other service error
        """
        object_key = self.get_key(party_model_id, model_version, component_name)

        try:
            self.client.head_object(Bucket=self.bucket, Key=object_key)
        except CosServiceError as e:
            # Only "object missing" maps to False; everything else propagates.
            if e.get_error_code() == 'NoSuchResource':
                return False
            raise e

        return True

    def upload(self, party_model_id, model_version, component_name):
        """
        Pack a component from the local cache and upload the archive
        :param party_model_id:
        :param model_version:
        :param component_name:
        :return: the second value returned by ``PipelinedComponent.pack_component``
        """
        component = PipelinedComponent(party_model_id=party_model_id, model_version=model_version)
        archive_path, hash_ = component.pack_component(component_name)
        target_key = self.get_key(party_model_id, model_version, component_name)

        self.client.upload_file(
            Bucket=self.bucket,
            LocalFilePath=archive_path,
            Key=target_key,
            EnableMD5=True,
        )
        return hash_

    def download(self, party_model_id, model_version, component_name, hash_=None):
        """
        Download a component archive and unpack it into the local cache
        :param party_model_id:
        :param model_version:
        :param component_name:
        :param hash_: forwarded to ``PipelinedComponent.unpack_component``
        :return:
        """
        component = PipelinedComponent(party_model_id=party_model_id, model_version=model_version)
        source_key = self.get_key(party_model_id, model_version, component_name)
        archive_path = component.get_archive_path(component_name)

        self.client.download_file(
            Bucket=self.bucket,
            Key=source_key,
            DestFilePath=archive_path,
            EnableCRC=True,
        )
        component.unpack_component(component_name, hash_)

    def copy(self, party_model_id, model_version, component_name, source_model_version):
        """
        Copy a component archive from another model version inside the same bucket
        :param party_model_id:
        :param model_version: version that receives the copy
        :param component_name:
        :param source_model_version: version the archive is copied from
        :return:
        """
        # Source and destination share this instance's bucket and region.
        copy_source = {
            'Bucket': self.bucket,
            'Key': self.get_key(party_model_id, source_model_version, component_name),
            'Region': self.region,
        }
        destination_key = self.get_key(party_model_id, model_version, component_name)

        self.client.copy(
            Bucket=self.bucket,
            Key=destination_key,
            CopySource=copy_source,
        )
|