sync_model.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
#
#  Copyright 2022 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from copy import deepcopy
from hashlib import sha256
from typing import Tuple, Type

from peewee import DoesNotExist

from fate_flow.db.db_models import (
    DB, PipelineComponentMeta,
    MachineLearningModelInfo as MLModel,
)
from fate_flow.db.service_registry import ServerRegistry
from fate_flow.model import (
    lock, model_storage_base,
    mysql_model_storage, tencent_cos_model_storage,
)
from fate_flow.pipelined_model import Pipelined
from fate_flow.pipelined_model.pipelined_model import PipelinedModel
from fate_flow.settings import HOST
  32. model_storage_map = {
  33. 'mysql': mysql_model_storage.MysqlModelStorage,
  34. 'tencent_cos': tencent_cos_model_storage.TencentCOSModelStorage,
  35. }
  36. component_storage_map = {
  37. 'mysql': mysql_model_storage.MysqlComponentStorage,
  38. 'tencent_cos': tencent_cos_model_storage.TencentCOSComponentStorage,
  39. }
  40. def get_storage(storage_map: dict) -> Tuple[model_storage_base.ModelStorageBase, dict]:
  41. store_address = deepcopy(ServerRegistry.MODEL_STORE_ADDRESS)
  42. store_type = store_address.pop('storage')
  43. if store_type not in storage_map:
  44. raise KeyError(f"Model storage '{store_type}' is not supported.")
  45. return storage_map[store_type], store_address
  46. class SyncModel(Pipelined):
  47. def __init__(self, **kwargs):
  48. super().__init__(**kwargs)
  49. self.pipelined_model = PipelinedModel(self.party_model_id, self.model_version)
  50. storage, storage_address = get_storage(model_storage_map)
  51. self.model_storage = storage()
  52. self.model_storage_parameters = {
  53. 'model_id': self.party_model_id,
  54. 'model_version': self.model_version,
  55. 'store_address': storage_address,
  56. }
  57. self.lock = DB.lock(
  58. sha256(
  59. '_'.join((
  60. 'sync_model',
  61. self.party_model_id,
  62. self.model_version,
  63. )).encode('utf-8')
  64. ).hexdigest(),
  65. -1,
  66. )
  67. @DB.connection_context()
  68. def db_exists(self):
  69. try:
  70. self.get_model()
  71. except DoesNotExist:
  72. return False
  73. else:
  74. return True
  75. def local_exists(self):
  76. return self.pipelined_model.exists()
  77. def remote_exists(self):
  78. return self.model_storage.exists(**self.model_storage_parameters)
  79. def get_model(self):
  80. return MLModel.get(
  81. MLModel.f_role == self.role,
  82. MLModel.f_party_id == self.party_id,
  83. MLModel.f_model_id == self.model_id,
  84. MLModel.f_model_version == self.model_version,
  85. )
  86. @DB.connection_context()
  87. def upload(self, force_update=False):
  88. if self.remote_exists() and not force_update:
  89. return
  90. with self.lock:
  91. model = self.get_model()
  92. hash_ = self.model_storage.store(
  93. force_update=force_update,
  94. **self.model_storage_parameters,
  95. )
  96. model.f_archive_sha256 = hash_
  97. model.f_archive_from_ip = HOST
  98. model.save()
  99. return model
  100. @DB.connection_context()
  101. def download(self, force_update=False):
  102. if self.local_exists() and not force_update:
  103. return
  104. with self.lock:
  105. model = self.get_model()
  106. self.model_storage.restore(
  107. force_update=force_update, hash_=model.f_archive_sha256,
  108. **self.model_storage_parameters,
  109. )
  110. return model
  111. class SyncComponent(Pipelined):
  112. def __init__(self, *, component_name, **kwargs):
  113. super().__init__(**kwargs)
  114. self.component_name = component_name
  115. self.pipelined_model = PipelinedModel(self.party_model_id, self.model_version)
  116. storage, storage_address = get_storage(component_storage_map)
  117. self.component_storage = storage(**storage_address)
  118. self.component_storage_parameters = (
  119. self.party_model_id,
  120. self.model_version,
  121. self.component_name,
  122. )
  123. self.query_args = (
  124. PipelineComponentMeta.f_role == self.role,
  125. PipelineComponentMeta.f_party_id == self.party_id,
  126. PipelineComponentMeta.f_model_id == self.model_id,
  127. PipelineComponentMeta.f_model_version == self.model_version,
  128. PipelineComponentMeta.f_component_name == self.component_name,
  129. )
  130. self.lock = DB.lock(
  131. sha256(
  132. '_'.join((
  133. 'sync_component',
  134. self.party_model_id,
  135. self.model_version,
  136. self.component_name,
  137. )).encode('utf-8')
  138. ).hexdigest(),
  139. -1,
  140. )
  141. @DB.connection_context()
  142. def db_exists(self):
  143. return PipelineComponentMeta.select().where(*self.query_args).count() > 0
  144. def local_exists(self):
  145. return self.pipelined_model.pipelined_component.exists(self.component_name)
  146. def remote_exists(self):
  147. with self.component_storage as storage:
  148. return storage.exists(*self.component_storage_parameters)
  149. def get_archive_hash(self):
  150. query = tuple(PipelineComponentMeta.select().where(*self.query_args).group_by(
  151. PipelineComponentMeta.f_archive_sha256, PipelineComponentMeta.f_archive_from_ip))
  152. if len(query) != 1:
  153. raise ValueError(f'The define_meta data of {self.component_name} in database is invalid.')
  154. return query[0].f_archive_sha256
  155. def update_archive_hash(self, hash_):
  156. PipelineComponentMeta.update(
  157. f_archive_sha256=hash_,
  158. f_archive_from_ip=HOST,
  159. ).where(*self.query_args).execute()
  160. @DB.connection_context()
  161. @lock
  162. def upload(self):
  163. # check the data in database
  164. self.get_archive_hash()
  165. with self.component_storage as storage:
  166. hash_ = storage.upload(*self.component_storage_parameters)
  167. self.update_archive_hash(hash_)
  168. @DB.connection_context()
  169. @lock
  170. def download(self):
  171. hash_ = self.get_archive_hash()
  172. with self.component_storage as storage:
  173. storage.download(*self.component_storage_parameters, hash_)
  174. @DB.connection_context()
  175. @lock
  176. def copy(self, source_model_version, hash_):
  177. with self.component_storage as storage:
  178. storage.copy(*self.component_storage_parameters, source_model_version)
  179. self.update_archive_hash(hash_)