mysql_model_storage.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the 'License');
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an 'AS IS' BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from re import I
  17. import sys
  18. from copy import deepcopy
  19. from peewee import (
  20. BigIntegerField, CharField, CompositeKey,
  21. IntegerField, PeeweeException, Value,
  22. )
  23. from playhouse.pool import PooledMySQLDatabase
  24. from fate_arch.common.base_utils import (
  25. current_timestamp, deserialize_b64,
  26. serialize_b64, timestamp_to_date,
  27. )
  28. from fate_arch.common.conf_utils import decrypt_database_password, decrypt_database_config
  29. from fate_arch.metastore.base_model import LongTextField
  30. from fate_flow.db.db_models import DataBaseModel
  31. from fate_flow.model.model_storage_base import ComponentStorageBase, ModelStorageBase
  32. from fate_flow.pipelined_model.pipelined_model import PipelinedModel
  33. from fate_flow.pipelined_model.pipelined_component import PipelinedComponent
  34. from fate_flow.utils.log_utils import getLogger
  # Logger shared by every storage class in this module.
  LOGGER = getLogger()

  # Pooled MySQL handle created without a database name; the storage classes
  # call DB.init(...) with the real connection settings before using it.
  DB = PooledMySQLDatabase(None)

  # Upper bound, in bytes, on how much of an archive is read per stored slice.
  SLICE_MAX_SIZE = 8 * 1024 * 1024
  class MysqlModelStorage(ModelStorageBase):
      '''
      Model storage backend that keeps packaged models in a MySQL table.

      The archive produced by PipelinedModel is read in chunks of at most
      SLICE_MAX_SIZE bytes. Each chunk is passed through serialize_b64 and
      saved as one MachineLearningModel row identified by model id, model
      version and slice index. All methods use the module-level DB handle.
      '''

      def exists(self, model_id: str, model_version: str, store_address: dict):
          '''
          Check whether any slice of the model is stored in mysql
          :param model_id:
          :param model_version:
          :param store_address: connection settings handed to get_connection
          :return: True if at least one matching row exists, otherwise False
          '''
          self.get_connection(store_address)

          try:
              with DB.connection_context():
                  counts = MachineLearningModel.select().where(
                      MachineLearningModel.f_model_id == model_id,
                      MachineLearningModel.f_model_version == model_version,
                  ).count()
              return counts > 0
          except PeeweeException as e:
              # Table doesn't exist
              if e.args and e.args[0] == 1146:
                  return False

              # Bare raise re-raises the active exception with its traceback intact.
              raise
          finally:
              self.close_connection()

      def store(self, model_id: str, model_version: str, store_address: dict, force_update: bool = False):
          '''
          Store the model from local cache to mysql
          :param model_id:
          :param model_version:
          :param store_address: connection settings handed to get_connection
          :param force_update: when True, skip the existence check and replace any stored slices
          :return: the value returned by PipelinedModel.packaging_model()
          :raises FileExistsError: the model is already stored and force_update is False
          :raises Exception: any failure while packaging or saving; the original error is
              logged and attached as the cause
          '''
          if not force_update and self.exists(model_id, model_version, store_address):
              raise FileExistsError(f'The model {model_id} {model_version} already exists in the database.')

          try:
              self.get_connection(store_address)
              DB.create_tables([MachineLearningModel])

              model = PipelinedModel(model_id, model_version)
              hash_ = model.packaging_model()

              with open(model.archive_model_file_path, 'rb') as fr, DB.connection_context():
                  # Drop slices left over from an earlier upload of the same model.
                  MachineLearningModel.delete().where(
                      MachineLearningModel.f_model_id == model_id,
                      MachineLearningModel.f_model_version == model_version,
                  ).execute()

                  LOGGER.info(f'Starting store model {model_id} {model_version}.')

                  # iter(callable, sentinel) keeps reading until read() returns b'' at EOF.
                  chunks = iter(lambda: fr.read(SLICE_MAX_SIZE), b'')
                  for slice_index, content in enumerate(chunks):
                      model_in_table = MachineLearningModel()
                      model_in_table.f_model_id = model_id
                      model_in_table.f_model_version = model_version
                      model_in_table.f_content = serialize_b64(content, to_str=True)
                      # NOTE(review): sys.getsizeof reports the size of the Python object,
                      # not the length of the serialized text - confirm that is what
                      # f_size is meant to hold.
                      model_in_table.f_size = sys.getsizeof(model_in_table.f_content)
                      model_in_table.f_slice_index = slice_index

                      # force_insert is required because the primary key is set by hand.
                      rows = model_in_table.save(force_insert=True)
                      if not rows:
                          raise IndexError(f'Save slice index {slice_index} failed')

                      LOGGER.info(f'Saved slice index {slice_index} of model {model_id} {model_version}.')
          except Exception as e:
              LOGGER.exception(e)
              # Chain explicitly so callers can reach the root cause via __cause__.
              raise Exception(f'Store model {model_id} {model_version} to mysql failed.') from e
          else:
              LOGGER.info(f'Store model {model_id} {model_version} to mysql successfully.')
              return hash_
          finally:
              self.close_connection()

      def restore(self, model_id: str, model_version: str, store_address: dict, force_update: bool = False, hash_: str = None):
          '''
          Restore model from mysql to local cache
          :param model_id:
          :param model_version:
          :param store_address: connection settings handed to get_connection
          :param force_update: passed through to PipelinedModel.unpack_model
          :param hash_: passed through to PipelinedModel.unpack_model
          :return:
          :raises Exception: any failure while reading or unpacking, including the case
              where no slice is stored; the original error is logged and attached as the cause
          '''
          model = PipelinedModel(model_id, model_version)
          self.get_connection(store_address)

          try:
              with DB.connection_context():
                  # Slices must be written back in the order they were stored.
                  models_in_tables = MachineLearningModel.select().where(
                      MachineLearningModel.f_model_id == model_id,
                      MachineLearningModel.f_model_version == model_version,
                  ).order_by(MachineLearningModel.f_slice_index)

                  # NOTE(review): the archive is opened with 'wb' (truncating it) before
                  # any row is read, so a missing model leaves an empty file behind.
                  with open(model.archive_model_file_path, 'wb') as fw:
                      for models_in_table in models_in_tables:
                          fw.write(deserialize_b64(models_in_table.f_content))

                      # Nothing written means the query returned no usable slice.
                      if fw.tell() == 0:
                          raise IndexError('Cannot found model in table.')

              model.unpack_model(model.archive_model_file_path, force_update, hash_)
          except Exception as e:
              LOGGER.exception(e)
              # Chain explicitly so callers can reach the root cause via __cause__.
              raise Exception(f'Restore model {model_id} {model_version} from mysql failed.') from e
          else:
              LOGGER.info(f'Restore model to {model.archive_model_file_path} from mysql successfully.')
          finally:
              self.close_connection()

      @staticmethod
      def get_connection(store_address: dict):
          '''
          Initialize the module-level DB handle from a store address
          :param store_address: dict holding 'database' plus the keyword arguments for DB.init;
              an optional 'storage' key is ignored
          :return:
          '''
          # Work on a copy so the caller's dict is not mutated by the pops below.
          store_address = deepcopy(store_address)
          store_address.pop('storage', None)
          database = store_address.pop('database')

          store_address = decrypt_database_config(store_address, 'password')
          DB.init(database, **store_address)

      @staticmethod
      def close_connection():
          '''
          Close the module-level DB connection, logging instead of raising on failure
          :return:
          '''
          if DB:
              try:
                  DB.close()
              except Exception as e:
                  # Best effort: a failed close must not mask the caller's own result.
                  LOGGER.exception(e)
  class MysqlComponentStorage(ComponentStorageBase):
      '''
      Component model storage backend that keeps packed components in a MySQL table.

      Intended to be used as a context manager: entering initializes the
      module-level DB handle with the stored settings, leaving closes it.
      Component archives are stored as slices of at most SLICE_MAX_SIZE bytes,
      one MachineLearningComponent row per slice.
      '''

      def __init__(self, database, user, password, host, port, **connect_kwargs):
          '''
          Remember the connection settings; no connection is opened here
          :param database: database name given to DB.init
          :param user:
          :param password: run through decrypt_database_password before being stored
          :param host:
          :param port:
          :param connect_kwargs: extra keyword arguments forwarded to DB.init
          '''
          self.database = database
          self.user = user
          self.password = decrypt_database_password(password)
          self.host = host
          self.port = port
          self.connect_kwargs = connect_kwargs

      def __enter__(self):
          DB.init(self.database, user=self.user, password=self.password, host=self.host, port=self.port, **self.connect_kwargs)

          return self

      def __exit__(self, *exc):
          DB.close()

      def exists(self, party_model_id, model_version, component_name):
          '''
          Check whether any slice of the component is stored in mysql
          :param party_model_id:
          :param model_version:
          :param component_name:
          :return: True if at least one matching row exists, otherwise False
          '''
          try:
              with DB.connection_context():
                  counts = MachineLearningComponent.select().where(
                      MachineLearningComponent.f_party_model_id == party_model_id,
                      MachineLearningComponent.f_model_version == model_version,
                      MachineLearningComponent.f_component_name == component_name,
                  ).count()
              return counts > 0
          except PeeweeException as e:
              # Table doesn't exist
              if e.args and e.args[0] == 1146:
                  return False

              # Bare raise re-raises the active exception with its traceback intact.
              raise

      def upload(self, party_model_id, model_version, component_name):
          '''
          Pack a component from local cache and store it in mysql
          :param party_model_id:
          :param model_version:
          :param component_name:
          :return: the hash returned by PipelinedComponent.pack_component
          :raises IndexError: saving a slice reported no affected row
          '''
          DB.create_tables([MachineLearningComponent])

          pipelined_component = PipelinedComponent(party_model_id=party_model_id, model_version=model_version)
          filename, hash_ = pipelined_component.pack_component(component_name)

          with open(filename, 'rb') as fr, DB.connection_context():
              # Drop slices left over from an earlier upload of the same component.
              MachineLearningComponent.delete().where(
                  MachineLearningComponent.f_party_model_id == party_model_id,
                  MachineLearningComponent.f_model_version == model_version,
                  MachineLearningComponent.f_component_name == component_name,
              ).execute()

              # iter(callable, sentinel) keeps reading until read() returns b'' at EOF.
              chunks = iter(lambda: fr.read(SLICE_MAX_SIZE), b'')
              for slice_index, content in enumerate(chunks):
                  model_in_table = MachineLearningComponent()
                  model_in_table.f_party_model_id = party_model_id
                  model_in_table.f_model_version = model_version
                  model_in_table.f_component_name = component_name
                  model_in_table.f_content = serialize_b64(content, to_str=True)
                  # NOTE(review): sys.getsizeof reports the size of the Python object,
                  # not the length of the serialized text - confirm that is what
                  # f_size is meant to hold.
                  model_in_table.f_size = sys.getsizeof(model_in_table.f_content)
                  model_in_table.f_slice_index = slice_index

                  rows = model_in_table.save(force_insert=True)
                  if not rows:
                      raise IndexError(f'Save slice index {slice_index} failed')

          return hash_

      def download(self, party_model_id, model_version, component_name, hash_=None):
          '''
          Fetch a component from mysql into local cache and unpack it
          :param party_model_id:
          :param model_version:
          :param component_name:
          :param hash_: passed through to PipelinedComponent.unpack_component
          :return:
          :raises IndexError: no content was found for the component
          '''
          with DB.connection_context():
              # Slices must be written back in the order they were stored.
              models_in_tables = MachineLearningComponent.select().where(
                  MachineLearningComponent.f_party_model_id == party_model_id,
                  MachineLearningComponent.f_model_version == model_version,
                  MachineLearningComponent.f_component_name == component_name,
              ).order_by(MachineLearningComponent.f_slice_index)

              pipelined_component = PipelinedComponent(party_model_id=party_model_id, model_version=model_version)

              # NOTE(review): the archive is opened with 'wb' (truncating it) before
              # any row is read, so a missing component leaves an empty file behind.
              with open(pipelined_component.get_archive_path(component_name), 'wb') as fw:
                  for models_in_table in models_in_tables:
                      fw.write(deserialize_b64(models_in_table.f_content))

                  # Nothing written means the query returned no usable slice.
                  if fw.tell() == 0:
                      raise IndexError('Cannot found component model in table.')

          pipelined_component.unpack_component(component_name, hash_)

      @DB.connection_context()
      def copy(self, party_model_id, model_version, component_name, source_model_version):
          '''
          Duplicate a stored component under another model version inside mysql
          :param party_model_id:
          :param model_version: version the copied rows are written under
          :param component_name:
          :param source_model_version: version the rows are read from
          :return:
          :raises IndexError: the insert reported no affected row
          '''
          now = current_timestamp()

          # Select the source slices, overriding the update stamps and the
          # model version with literal values for the new rows.
          source = MachineLearningComponent.select(
              MachineLearningComponent.f_create_time,
              MachineLearningComponent.f_create_date,
              Value(now).alias('f_update_time'),
              Value(timestamp_to_date(now)).alias('f_update_date'),

              MachineLearningComponent.f_party_model_id,
              Value(model_version).alias('f_model_version'),
              MachineLearningComponent.f_component_name,

              MachineLearningComponent.f_size,
              MachineLearningComponent.f_content,
              MachineLearningComponent.f_slice_index,
          ).where(
              MachineLearningComponent.f_party_model_id == party_model_id,
              MachineLearningComponent.f_model_version == source_model_version,
              MachineLearningComponent.f_component_name == component_name,
          ).order_by(MachineLearningComponent.f_slice_index)

          # Target columns are listed in the same order as the select above.
          rows = MachineLearningComponent.insert_from(source, (
              MachineLearningComponent.f_create_time,
              MachineLearningComponent.f_create_date,
              MachineLearningComponent.f_update_time,
              MachineLearningComponent.f_update_date,

              MachineLearningComponent.f_party_model_id,
              MachineLearningComponent.f_model_version,
              MachineLearningComponent.f_component_name,

              MachineLearningComponent.f_size,
              MachineLearningComponent.f_content,
              MachineLearningComponent.f_slice_index,
          )).execute()

          if not rows:
              raise IndexError('Copy component model failed.')
  class MachineLearningModel(DataBaseModel):
      '''
      One slice of a packaged model archive as stored in mysql.
      '''
      # Identity of the model the slice belongs to.
      f_model_id = CharField(max_length=100, index=True)
      f_model_version = CharField(max_length=100, index=True)

      # Size value recorded for the slice and the serialized slice text itself.
      f_size = BigIntegerField(default=0)
      f_content = LongTextField(default='')

      # Position of the slice within the archive, starting at 0.
      f_slice_index = IntegerField(default=0, index=True)

      class Meta:
          db_table = 't_machine_learning_model'
          # A row is identified by model id, model version and slice position.
          primary_key = CompositeKey(
              'f_model_id',
              'f_model_version',
              'f_slice_index',
          )
  class MachineLearningComponent(DataBaseModel):
      '''
      One slice of a packed component archive as stored in mysql.
      '''
      # Identity of the component the slice belongs to.
      f_party_model_id = CharField(max_length=100, index=True)
      f_model_version = CharField(max_length=100, index=True)
      f_component_name = CharField(max_length=100, index=True)

      # Size value recorded for the slice and the serialized slice text itself.
      f_size = BigIntegerField(default=0)
      f_content = LongTextField(default='')

      # Position of the slice within the archive, starting at 0.
      f_slice_index = IntegerField(default=0, index=True)

      class Meta:
          db_table = 't_machine_learning_component'
          # Unique index (the trailing True): one row per component slice.
          indexes = (
              (
                  (
                      'f_party_model_id',
                      'f_model_version',
                      'f_component_name',
                      'f_slice_index',
                  ),
                  True,
              ),
          )