tencent_cos_model_storage.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from copy import deepcopy
  17. from qcloud_cos import CosConfig, CosS3Client
  18. from qcloud_cos.cos_exception import CosServiceError
  19. from fate_flow.model.model_storage_base import ComponentStorageBase, ModelStorageBase
  20. from fate_flow.pipelined_model.pipelined_model import PipelinedModel
  21. from fate_flow.pipelined_model.pipelined_component import PipelinedComponent
  22. from fate_flow.utils.log_utils import getLogger
# Module-level logger; the model storage class below uses it to record store/restore results.
LOGGER = getLogger()
  24. class TencentCOSModelStorage(ModelStorageBase):
  25. def store_key(self, model_id: str, model_version: str):
  26. return f'FATEFlow/PipelinedModel/{model_id}/{model_version}.zip'
  27. def exists(self, model_id: str, model_version: str, store_address: dict):
  28. store_key = self.store_key(model_id, model_version)
  29. cos = self.get_connection(store_address)
  30. try:
  31. cos.head_object(
  32. Bucket=store_address["Bucket"],
  33. Key=store_key,
  34. )
  35. except CosServiceError as e:
  36. if e.get_error_code() != 'NoSuchResource':
  37. raise e
  38. return False
  39. else:
  40. return True
  41. def store(self, model_id: str, model_version: str, store_address: dict, force_update: bool = False):
  42. """
  43. Store the model from local cache to cos
  44. :param model_id:
  45. :param model_version:
  46. :param store_address:
  47. :param force_update:
  48. :return:
  49. """
  50. store_key = self.store_key(model_id, model_version)
  51. if not force_update and self.exists(model_id, model_version, store_address):
  52. raise FileExistsError(f"The object {store_key} already exists.")
  53. model = PipelinedModel(model_id, model_version)
  54. cos = self.get_connection(store_address)
  55. try:
  56. hash_ = model.packaging_model()
  57. response = cos.upload_file(
  58. Bucket=store_address["Bucket"],
  59. LocalFilePath=model.archive_model_file_path,
  60. Key=store_key,
  61. EnableMD5=True,
  62. )
  63. except Exception as e:
  64. LOGGER.exception(e)
  65. raise Exception(f"Store model {model_id} {model_version} to Tencent COS failed.")
  66. else:
  67. LOGGER.info(f"Store model {model_id} {model_version} to Tencent COS successfully. "
  68. f"Archive path: {model.archive_model_file_path} Key: {store_key} ETag: {response['ETag']}")
  69. return hash_
  70. def restore(self, model_id: str, model_version: str, store_address: dict, force_update: bool = False, hash_: str = None):
  71. """
  72. Restore model from cos to local cache
  73. :param model_id:
  74. :param model_version:
  75. :param store_address:
  76. :return:
  77. """
  78. store_key = self.store_key(model_id, model_version)
  79. model = PipelinedModel(model_id, model_version)
  80. cos = self.get_connection(store_address)
  81. try:
  82. cos.download_file(
  83. Bucket=store_address["Bucket"],
  84. Key=store_key,
  85. DestFilePath=model.archive_model_file_path,
  86. EnableCRC=True,
  87. )
  88. model.unpack_model(model.archive_model_file_path, force_update, hash_)
  89. except Exception as e:
  90. LOGGER.exception(e)
  91. raise Exception(f"Restore model {model_id} {model_version} from Tencent COS failed.")
  92. else:
  93. LOGGER.info(f"Restore model {model_id} {model_version} from Tencent COS successfully. "
  94. f"Archive path: {model.archive_model_file_path} Key: {store_key}")
  95. @staticmethod
  96. def get_connection(store_address: dict):
  97. store_address = deepcopy(store_address)
  98. store_address.pop('storage', None)
  99. store_address.pop('Bucket')
  100. return CosS3Client(CosConfig(**store_address))
  101. class TencentCOSComponentStorage(ComponentStorageBase):
  102. def __init__(self, Region, SecretId, SecretKey, Bucket):
  103. self.client = CosS3Client(CosConfig(Region=Region, SecretId=SecretId, SecretKey=SecretKey))
  104. self.bucket = Bucket
  105. self.region = Region
  106. def get_key(self, party_model_id, model_version, component_name):
  107. return f'FATEFlow/PipelinedComponent/{party_model_id}/{model_version}/{component_name}.zip'
  108. def exists(self, party_model_id, model_version, component_name):
  109. key = self.get_key(party_model_id, model_version, component_name)
  110. try:
  111. self.client.head_object(
  112. Bucket=self.bucket,
  113. Key=key,
  114. )
  115. except CosServiceError as e:
  116. if e.get_error_code() != 'NoSuchResource':
  117. raise e
  118. return False
  119. else:
  120. return True
  121. def upload(self, party_model_id, model_version, component_name):
  122. pipelined_component = PipelinedComponent(party_model_id=party_model_id, model_version=model_version)
  123. filename, hash_ = pipelined_component.pack_component(component_name)
  124. self.client.upload_file(
  125. Bucket=self.bucket,
  126. LocalFilePath=filename,
  127. Key=self.get_key(party_model_id, model_version, component_name),
  128. EnableMD5=True,
  129. )
  130. return hash_
  131. def download(self, party_model_id, model_version, component_name, hash_=None):
  132. pipelined_component = PipelinedComponent(party_model_id=party_model_id, model_version=model_version)
  133. self.client.download_file(
  134. Bucket=self.bucket,
  135. Key=self.get_key(party_model_id, model_version, component_name),
  136. DestFilePath=pipelined_component.get_archive_path(component_name),
  137. EnableCRC=True,
  138. )
  139. pipelined_component.unpack_component(component_name, hash_)
  140. def copy(self, party_model_id, model_version, component_name, source_model_version):
  141. self.client.copy(
  142. Bucket=self.bucket,
  143. Key=self.get_key(party_model_id, model_version, component_name),
  144. CopySource={
  145. 'Bucket': self.bucket,
  146. 'Key': self.get_key(party_model_id, source_model_version, component_name),
  147. 'Region': self.region,
  148. },
  149. )