db_models.py 21 KB


  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import datetime
  17. import inspect
  18. import os
  19. import sys
  20. from functools import wraps
  21. from peewee import (
  22. BigAutoField, BigIntegerField, BooleanField, CharField,
  23. CompositeKey, Insert, IntegerField, TextField,
  24. )
  25. from playhouse.hybrid import hybrid_property
  26. from playhouse.pool import PooledMySQLDatabase
  27. from fate_arch.common import file_utils
  28. from fate_arch.metastore.base_model import (
  29. BaseModel, DateTimeField, JSONField, ListField,
  30. LongTextField, SerializedField, SerializedType,
  31. )
  32. from fate_flow.db.runtime_config import RuntimeConfig
  33. from fate_flow.settings import DATABASE, IS_STANDALONE, stat_logger
  34. from fate_flow.utils.log_utils import getLogger
  35. from fate_flow.utils.object_utils import from_dict_hook
# Module-level logger obtained from fate_flow.utils.log_utils.getLogger().
LOGGER = getLogger()
  37. class JsonSerializedField(SerializedField):
  38. def __init__(self, object_hook=from_dict_hook, object_pairs_hook=None, **kwargs):
  39. super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
  40. object_pairs_hook=object_pairs_hook, **kwargs)
  41. def singleton(cls, *args, **kw):
  42. instances = {}
  43. def _singleton():
  44. key = str(cls) + str(os.getpid())
  45. if key not in instances:
  46. instances[key] = cls(*args, **kw)
  47. return instances[key]
  48. return _singleton
  49. @singleton
  50. class BaseDataBase:
  51. def __init__(self):
  52. database_config = DATABASE.copy()
  53. db_name = database_config.pop("name")
  54. if IS_STANDALONE and not bool(int(os.environ.get("FORCE_USE_MYSQL", 0))):
  55. # sqlite does not support other options
  56. Insert.on_conflict = lambda self, *args, **kwargs: self.on_conflict_replace()
  57. from playhouse.apsw_ext import APSWDatabase
  58. self.database_connection = APSWDatabase(file_utils.get_project_base_directory("fate_sqlite.db"))
  59. RuntimeConfig.init_config(USE_LOCAL_DATABASE=True)
  60. stat_logger.info('init sqlite database on standalone mode successfully')
  61. else:
  62. self.database_connection = PooledMySQLDatabase(db_name, **database_config)
  63. stat_logger.info('init mysql database on cluster mode successfully')
  64. class DatabaseLock:
  65. def __init__(self, lock_name, timeout=10, db=None):
  66. self.lock_name = lock_name
  67. self.timeout = int(timeout)
  68. self.db = db if db else DB
  69. def lock(self):
  70. # SQL parameters only support %s format placeholders
  71. cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
  72. ret = cursor.fetchone()
  73. if ret[0] == 0:
  74. raise Exception(f'acquire mysql lock {self.lock_name} timeout')
  75. elif ret[0] == 1:
  76. return True
  77. else:
  78. raise Exception(f'failed to acquire lock {self.lock_name}')
  79. def unlock(self):
  80. cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name, ))
  81. ret = cursor.fetchone()
  82. if ret[0] == 0:
  83. raise Exception(f'mysql lock {self.lock_name} was not established by this thread')
  84. elif ret[0] == 1:
  85. return True
  86. else:
  87. raise Exception(f'mysql lock {self.lock_name} does not exist')
  88. def __enter__(self):
  89. if isinstance(self.db, PooledMySQLDatabase):
  90. self.lock()
  91. return self
  92. def __exit__(self, exc_type, exc_val, exc_tb):
  93. if isinstance(self.db, PooledMySQLDatabase):
  94. self.unlock()
  95. def __call__(self, func):
  96. @wraps(func)
  97. def magic(*args, **kwargs):
  98. with self:
  99. return func(*args, **kwargs)
  100. return magic
# Database handle for this process (SQLite or pooled MySQL, chosen by BaseDataBase).
# Must be bound before DataBaseModel, whose Meta.database references it.
DB = BaseDataBase().database_connection
# Expose the lock helper on the handle, so callers can write e.g. ``DB.lock("name")``.
DB.lock = DatabaseLock
  103. def close_connection():
  104. try:
  105. if DB:
  106. DB.close()
  107. except Exception as e:
  108. LOGGER.exception(e)
class DataBaseModel(BaseModel):
    """Base class for the table models in this module; binds them to ``DB``."""

    class Meta:
        # Subclasses inherit this Meta option, so every model uses the shared handle.
        database = DB
  112. @DB.connection_context()
  113. def init_database_tables():
  114. members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
  115. table_objs = []
  116. create_failed_list = []
  117. for name, obj in members:
  118. if obj != DataBaseModel and issubclass(obj, DataBaseModel):
  119. table_objs.append(obj)
  120. LOGGER.info(f"start create table {obj.__name__}")
  121. try:
  122. obj.create_table()
  123. LOGGER.info(f"create table success: {obj.__name__}")
  124. except Exception as e:
  125. LOGGER.exception(e)
  126. create_failed_list.append(obj.__name__)
  127. if create_failed_list:
  128. LOGGER.info(f"create tables failed: {create_failed_list}")
  129. raise Exception(f"create tables failed: {create_failed_list}")
  130. def fill_db_model_object(model_object, human_model_dict):
  131. for k, v in human_model_dict.items():
  132. attr_name = 'f_%s' % k
  133. if hasattr(model_object.__class__, attr_name):
  134. setattr(model_object, attr_name, v)
  135. return model_object
class Job(DataBaseModel):
    """Job record stored in ``t_job``: one row per (job id, role, party id).

    The first group of fields holds configuration common to all parties.
    The second group holds this party's own state: signals, engine/resource
    fields and timestamps.
    """
    # multi-party common configuration
    f_user_id = CharField(max_length=25, null=True)
    f_job_id = CharField(max_length=25, index=True)
    f_name = CharField(max_length=500, null=True, default='')
    f_description = TextField(null=True, default='')
    f_tag = CharField(max_length=50, null=True, default='')
    f_dsl = JSONField()
    f_runtime_conf = JSONField()
    f_runtime_conf_on_party = JSONField()
    f_train_runtime_conf = JSONField(null=True)
    f_roles = JSONField()
    f_initiator_role = CharField(max_length=50)
    f_initiator_party_id = CharField(max_length=50)
    f_status = CharField(max_length=50)
    f_status_code = IntegerField(null=True)
    f_user = JSONField()
    # this party configuration
    f_role = CharField(max_length=50, index=True)
    f_party_id = CharField(max_length=10, index=True)
    f_is_initiator = BooleanField(null=True, default=False)
    f_progress = IntegerField(null=True, default=0)
    f_ready_signal = BooleanField(default=False)
    f_ready_time = BigIntegerField(null=True)
    f_cancel_signal = BooleanField(default=False)
    f_cancel_time = BigIntegerField(null=True)
    f_rerun_signal = BooleanField(default=False)
    f_end_scheduling_updates = IntegerField(null=True, default=0)
    f_engine_name = CharField(max_length=50, null=True)
    f_engine_type = CharField(max_length=10, null=True)
    f_cores = IntegerField(default=0)
    f_memory = IntegerField(default=0)  # MB
    f_remaining_cores = IntegerField(default=0)
    f_remaining_memory = IntegerField(default=0)  # MB
    f_resource_in_use = BooleanField(default=False)
    f_apply_resource_time = BigIntegerField(null=True)
    f_return_resource_time = BigIntegerField(null=True)
    f_inheritance_info = JSONField(null=True)
    f_inheritance_status = CharField(max_length=50, null=True)
    f_start_time = BigIntegerField(null=True)
    f_start_date = DateTimeField(null=True)
    f_end_time = BigIntegerField(null=True)
    f_end_date = DateTimeField(null=True)
    f_elapsed = BigIntegerField(null=True)

    class Meta:
        db_table = "t_job"
        # The same job id appears once per (role, party id) combination.
        primary_key = CompositeKey('f_job_id', 'f_role', 'f_party_id')
class Task(DataBaseModel):
    """Task record stored in ``t_task``.

    Keyed by (job id, task id, task version, role, party id), so every task
    version gets its own row per (role, party id).
    """
    # multi-party common configuration
    f_job_id = CharField(max_length=25, index=True)
    f_component_name = TextField()
    f_component_module = CharField(max_length=200)
    f_task_id = CharField(max_length=100)
    f_task_version = BigIntegerField()
    f_initiator_role = CharField(max_length=50)
    # NOTE(review): integer default on a CharField; confirm '-1' as a string is what is stored.
    f_initiator_party_id = CharField(max_length=50, default=-1)
    f_federated_mode = CharField(max_length=10)
    f_federated_status_collect_type = CharField(max_length=10)
    f_status = CharField(max_length=50, index=True)
    f_status_code = IntegerField(null=True)
    f_auto_retries = IntegerField(default=0)
    f_auto_retry_delay = IntegerField(default=0)
    # this party configuration
    f_role = CharField(max_length=50, index=True)
    f_party_id = CharField(max_length=10, index=True)
    f_run_on_this_party = BooleanField(null=True, index=True, default=False)
    f_worker_id = CharField(null=True, max_length=100)
    f_cmd = JSONField(null=True)
    f_run_ip = CharField(max_length=100, null=True)
    f_run_port = IntegerField(null=True)
    f_run_pid = IntegerField(null=True)
    f_party_status = CharField(max_length=50)
    f_provider_info = JSONField()
    f_component_parameters = JSONField()
    f_engine_conf = JSONField(null=True)
    f_kill_status = BooleanField(default=False)
    f_error_report = TextField(default="")
    f_start_time = BigIntegerField(null=True)
    f_start_date = DateTimeField(null=True)
    f_end_time = BigIntegerField(null=True)
    f_end_date = DateTimeField(null=True)
    f_elapsed = BigIntegerField(null=True)

    class Meta:
        db_table = "t_task"
        primary_key = CompositeKey('f_job_id', 'f_task_id', 'f_task_version', 'f_role', 'f_party_id')
class TrackingMetric(DataBaseModel):
    """Tracked metric rows, sharded across tables named ``t_tracking_metric_<index>``.

    Call :meth:`model` to get a model bound to a particular shard.
    """
    # Cache of dynamically created shard subclasses, keyed by 'TrackingMetric_<index>'.
    _mapper = {}

    @classmethod
    def model(cls, table_index=None, date=None):
        """Return an instance of the model class bound to one shard table.

        Args:
            table_index: shard suffix. When falsy it is derived from ``date``,
                or from ``datetime.datetime.now()``, formatted as ``%Y%m%d``.
            date: date/datetime used to derive the index when ``table_index``
                is not given.

        Returns:
            A new instance of the cached subclass whose table is
            ``t_tracking_metric_<table_index>``. The subclass is created on
            first use.
        """
        if not table_index:
            table_index = date.strftime(
                '%Y%m%d') if date else datetime.datetime.now().strftime(
                '%Y%m%d')
        class_name = 'TrackingMetric_%s' % table_index
        ModelClass = TrackingMetric._mapper.get(class_name, None)
        if ModelClass is None:
            class Meta:
                db_table = '%s_%s' % ('t_tracking_metric', table_index)
            attrs = {'__module__': cls.__module__, 'Meta': Meta}
            ModelClass = type("%s_%s" % (cls.__name__, table_index), (cls,),
                              attrs)
            TrackingMetric._mapper[class_name] = ModelClass
        return ModelClass()
    f_id = BigAutoField(primary_key=True)
    f_job_id = CharField(max_length=25, index=True)
    f_component_name = CharField(max_length=30, index=True)
    f_task_id = CharField(max_length=100, null=True)
    f_task_version = BigIntegerField(null=True)
    f_role = CharField(max_length=10, index=True)
    f_party_id = CharField(max_length=10)
    f_metric_namespace = CharField(max_length=80, index=True)
    f_metric_name = CharField(max_length=80, index=True)
    f_key = CharField(max_length=200)
    f_value = LongTextField()
    f_type = IntegerField()  # 0 is data, 1 is meta
class TrackingOutputDataInfo(DataBaseModel):
    """Output-table records, sharded across tables named ``t_tracking_output_data_info_<index>``.

    Call :meth:`model` to get a model bound to a particular shard.
    """
    # Cache of dynamically created shard subclasses, keyed by 'TrackingOutputDataInfo_<index>'.
    _mapper = {}

    @classmethod
    def model(cls, table_index=None, date=None):
        """Return an instance of the model class bound to one shard table.

        Args:
            table_index: shard suffix. When falsy it is derived from ``date``,
                or from ``datetime.datetime.now()``, formatted as ``%Y%m%d``.
            date: date/datetime used to derive the index when ``table_index``
                is not given.

        Returns:
            A new instance of the cached subclass whose table is
            ``t_tracking_output_data_info_<table_index>``. That table is keyed
            by (job id, task id, task version, data name, role, party id).
        """
        if not table_index:
            table_index = date.strftime(
                '%Y%m%d') if date else datetime.datetime.now().strftime(
                '%Y%m%d')
        class_name = 'TrackingOutputDataInfo_%s' % table_index
        ModelClass = TrackingOutputDataInfo._mapper.get(class_name, None)
        if ModelClass is None:
            class Meta:
                db_table = '%s_%s' % ('t_tracking_output_data_info', table_index)
                primary_key = CompositeKey(
                    'f_job_id', 'f_task_id', 'f_task_version',
                    'f_data_name', 'f_role', 'f_party_id',
                )
            attrs = {'__module__': cls.__module__, 'Meta': Meta}
            ModelClass = type("%s_%s" % (cls.__name__, table_index), (cls,),
                              attrs)
            TrackingOutputDataInfo._mapper[class_name] = ModelClass
        return ModelClass()
    # multi-party common configuration
    f_job_id = CharField(max_length=25, index=True)
    f_component_name = TextField()
    f_task_id = CharField(max_length=100, null=True, index=True)
    f_task_version = BigIntegerField(null=True)
    f_data_name = CharField(max_length=30)
    # this party configuration
    f_role = CharField(max_length=50, index=True)
    f_party_id = CharField(max_length=10, index=True)
    f_table_name = CharField(max_length=500, null=True)
    f_table_namespace = CharField(max_length=500, null=True)
    f_description = TextField(null=True, default='')
  285. class MachineLearningModelInfo(DataBaseModel):
  286. f_role = CharField(max_length=50)
  287. f_party_id = CharField(max_length=10)
  288. f_roles = JSONField(default={})
  289. f_job_id = CharField(max_length=25, index=True)
  290. f_model_id = CharField(max_length=100, index=True)
  291. f_model_version = CharField(max_length=100, index=True)
  292. f_size = BigIntegerField(default=0)
  293. f_initiator_role = CharField(max_length=50)
  294. f_initiator_party_id = CharField(max_length=50, default=-1)
  295. # TODO: deprecated. use f_train_runtime_conf instead
  296. f_runtime_conf = JSONField(default={})
  297. f_train_dsl = JSONField(default={})
  298. f_train_runtime_conf = JSONField(default={})
  299. f_runtime_conf_on_party = JSONField(default={})
  300. f_inference_dsl = JSONField(default={})
  301. f_fate_version = CharField(max_length=10, null=True, default='')
  302. f_parent = BooleanField(null=True, default=None)
  303. f_parent_info = JSONField(default={})
  304. # loaded times in api /model/load/do
  305. f_loaded_times = IntegerField(default=0)
  306. # imported from api /model/import
  307. f_imported = IntegerField(default=0)
  308. f_archive_sha256 = CharField(max_length=100, null=True)
  309. f_archive_from_ip = CharField(max_length=100, null=True)
  310. @hybrid_property
  311. def f_party_model_id(self):
  312. return '#'.join([self.f_role, self.f_party_id, self.f_model_id])
  313. class Meta:
  314. db_table = "t_machine_learning_model_info"
  315. primary_key = CompositeKey('f_role', 'f_party_id', 'f_model_id', 'f_model_version')
class DataTableTracking(DataBaseModel):
    """Table lineage record stored in ``t_data_table_tracking``.

    Each row, identified by the auto-increment ``f_table_id``, stores a table's
    name/namespace next to its parent and source table names/namespaces.
    """
    f_table_id = BigAutoField(primary_key=True)
    f_table_name = CharField(max_length=300, null=True)
    f_table_namespace = CharField(max_length=300, null=True)
    f_job_id = CharField(max_length=25, index=True, null=True)
    f_have_parent = BooleanField(default=False)
    f_parent_number = IntegerField(default=0)
    f_parent_table_name = CharField(max_length=500, null=True)
    f_parent_table_namespace = CharField(max_length=500, null=True)
    f_source_table_name = CharField(max_length=500, null=True)
    f_source_table_namespace = CharField(max_length=500, null=True)

    class Meta:
        db_table = "t_data_table_tracking"
class CacheRecord(DataBaseModel):
    """Cache entry stored in ``t_cache_record``.

    The payload ``f_cache`` is stored through JsonSerializedField.
    """
    f_cache_key = CharField(max_length=500)
    f_cache = JsonSerializedField()
    f_job_id = CharField(max_length=25, index=True, null=True)
    f_role = CharField(max_length=50, index=True, null=True)
    f_party_id = CharField(max_length=10, index=True, null=True)
    f_component_name = TextField(null=True)
    f_task_id = CharField(max_length=100, null=True)
    f_task_version = BigIntegerField(null=True, index=True)
    f_cache_name = CharField(max_length=50, null=True)
    # NOTE(review): named t_ttl, unlike the f_ prefix used by every other column.
    # Renaming it would change the column name, so it is left as-is.
    t_ttl = BigIntegerField(default=0)

    class Meta:
        db_table = "t_cache_record"
class ModelTag(DataBaseModel):
    """Association row stored in ``t_model_tag``, pairing ``f_m_id`` with ``f_t_id``.

    Presumably ``f_m_id`` is a model id and ``f_t_id`` a ``Tag.f_id``; confirm
    against callers.
    """
    f_id = BigAutoField(primary_key=True)
    f_m_id = CharField(max_length=25, null=False)
    f_t_id = BigIntegerField(null=False)

    class Meta:
        db_table = "t_model_tag"
class Tag(DataBaseModel):
    """Tag stored in ``t_tags``; ``f_name`` is declared unique."""
    f_id = BigAutoField(primary_key=True)
    f_name = CharField(max_length=100, unique=True)
    f_desc = TextField(null=True)

    class Meta:
        db_table = "t_tags"
  354. class ComponentSummary(DataBaseModel):
  355. _mapper = {}
  356. @classmethod
  357. def model(cls, table_index=None, date=None):
  358. if not table_index:
  359. table_index = date.strftime(
  360. '%Y%m%d') if date else datetime.datetime.now().strftime(
  361. '%Y%m%d')
  362. class_name = 'ComponentSummary_%s' % table_index
  363. ModelClass = TrackingMetric._mapper.get(class_name, None)
  364. if ModelClass is None:
  365. class Meta:
  366. db_table = '%s_%s' % ('t_component_summary', table_index)
  367. attrs = {'__module__': cls.__module__, 'Meta': Meta}
  368. ModelClass = type("%s_%s" % (cls.__name__, table_index), (cls,), attrs)
  369. ComponentSummary._mapper[class_name] = ModelClass
  370. return ModelClass()
  371. f_id = BigAutoField(primary_key=True)
  372. f_job_id = CharField(max_length=25, index=True)
  373. f_role = CharField(max_length=25, index=True)
  374. f_party_id = CharField(max_length=10, index=True)
  375. f_component_name = CharField(max_length=50)
  376. f_task_id = CharField(max_length=50, null=True, index=True)
  377. f_task_version = CharField(max_length=50, null=True)
  378. f_summary = LongTextField()
class EngineRegistry(DataBaseModel):
    """Registered engine stored in ``t_engine_registry``, keyed by (engine name, engine type)."""
    f_engine_type = CharField(max_length=10, index=True)
    f_engine_name = CharField(max_length=50, index=True)
    f_engine_entrance = CharField(max_length=50, index=True)
    f_engine_config = JSONField()
    f_cores = IntegerField()
    f_memory = IntegerField()  # MB
    f_remaining_cores = IntegerField()
    f_remaining_memory = IntegerField()  # MB
    f_nodes = IntegerField()

    class Meta:
        db_table = "t_engine_registry"
        primary_key = CompositeKey('f_engine_name', 'f_engine_type')
# component registry
class ComponentRegistryInfo(DataBaseModel):
    """Component entry stored in ``t_component_registry``.

    Records ``f_module`` for each (provider name, version, component name) key.
    """
    f_provider_name = CharField(max_length=20, index=True)
    f_version = CharField(max_length=10, index=True)
    f_component_name = CharField(max_length=30, index=True)
    f_module = CharField(max_length=128)

    class Meta:
        db_table = "t_component_registry"
        primary_key = CompositeKey('f_provider_name', 'f_version', 'f_component_name')
class ComponentProviderInfo(DataBaseModel):
    """Component provider stored in ``t_component_provider_info``, keyed by (provider name, version)."""
    f_provider_name = CharField(max_length=20, index=True)
    f_version = CharField(max_length=10, index=True)
    f_class_path = JSONField()
    f_path = CharField(max_length=128, null=False)
    f_python = CharField(max_length=128, null=False)

    class Meta:
        db_table = "t_component_provider_info"
        primary_key = CompositeKey('f_provider_name', 'f_version')
class ComponentInfo(DataBaseModel):
    """Component entry stored in ``t_component_info``, keyed by component name.

    Also stores the component's aliases, default provider and supported providers.
    """
    f_component_name = CharField(max_length=30, primary_key=True)
    f_component_alias = JSONField()
    f_default_provider = CharField(max_length=20)
    f_support_provider = ListField(null=True)

    class Meta:
        db_table = "t_component_info"
class WorkerInfo(DataBaseModel):
    """Worker record stored in ``t_worker``, keyed by ``f_worker_id``.

    Links the worker to a job/task/version/role/party and stores its process
    address (ip, pid, ports), config, command and timestamps.
    """
    f_worker_id = CharField(max_length=100, primary_key=True)
    f_worker_name = CharField(max_length=50, index=True)
    f_job_id = CharField(max_length=25, index=True)
    f_task_id = CharField(max_length=100)
    f_task_version = BigIntegerField(index=True)
    f_role = CharField(max_length=50)
    f_party_id = CharField(max_length=10, index=True)
    f_run_ip = CharField(max_length=100, null=True)
    f_run_pid = IntegerField(null=True)
    f_http_port = IntegerField(null=True)
    f_grpc_port = IntegerField(null=True)
    f_config = JSONField(null=True)
    f_cmd = JSONField(null=True)
    f_start_time = BigIntegerField(null=True)
    f_start_date = DateTimeField(null=True)
    f_end_time = BigIntegerField(null=True)
    f_end_date = DateTimeField(null=True)

    class Meta:
        db_table = "t_worker"
class DependenciesStorageMeta(DataBaseModel):
    """Dependency-storage record stored in ``t_dependencies_storage_meta``.

    Keyed by (storage engine, type, version).
    """
    f_storage_engine = CharField(max_length=30)
    f_type = CharField(max_length=20)
    f_version = CharField(max_length=10, index=True)
    f_storage_path = CharField(max_length=256, null=True)
    f_snapshot_time = BigIntegerField(null=True)
    f_fate_flow_snapshot_time = BigIntegerField(null=True)
    f_dependencies_conf = JSONField(null=True)
    f_upload_status = BooleanField(default=False)
    f_pid = IntegerField(null=True)

    class Meta:
        db_table = "t_dependencies_storage_meta"
        primary_key = CompositeKey('f_storage_engine', 'f_type', 'f_version')
class ServerRegistryInfo(DataBaseModel):
    """Server address record stored in ``t_server_registry_info``.

    No explicit primary key is declared; ``f_server_name`` is indexed but not
    declared unique.
    """
    f_server_name = CharField(max_length=30, index=True)
    f_host = CharField(max_length=30)
    f_port = IntegerField()
    f_protocol = CharField(max_length=10)

    class Meta:
        db_table = "t_server_registry_info"
class ServiceRegistryInfo(DataBaseModel):
    """Service endpoint record stored in ``t_service_registry_info``.

    Keyed by (server name, service name). Stores the URL, HTTP method and
    optional params/data/headers.
    """
    f_server_name = CharField(max_length=30)
    f_service_name = CharField(max_length=30)
    f_url = CharField(max_length=100)
    f_method = CharField(max_length=10)
    f_params = JSONField(null=True)
    f_data = JSONField(null=True)
    f_headers = JSONField(null=True)

    class Meta:
        db_table = "t_service_registry_info"
        primary_key = CompositeKey('f_server_name', 'f_service_name')
class SiteKeyInfo(DataBaseModel):
    """Named key per party stored in ``t_site_key_info``, keyed by (party id, key name)."""
    f_party_id = CharField(max_length=10, index=True)
    f_key_name = CharField(max_length=10, index=True)
    f_key = LongTextField()

    class Meta:
        db_table = "t_site_key_info"
        primary_key = CompositeKey('f_party_id', 'f_key_name')
class PipelineComponentMeta(DataBaseModel):
    """Per-component model metadata stored in ``t_pipeline_component_meta``.

    Uses the default primary key. Uniqueness is enforced by a composite index
    on (model id, model version, role, party id, component name).
    """
    f_model_id = CharField(max_length=100, index=True)
    f_model_version = CharField(max_length=100, index=True)
    f_role = CharField(max_length=50, index=True)
    f_party_id = CharField(max_length=10, index=True)
    f_component_name = CharField(max_length=100, index=True)
    f_component_module_name = CharField(max_length=100)
    f_model_alias = CharField(max_length=100, index=True)
    f_model_proto_index = JSONField(null=True)
    f_run_parameters = JSONField(null=True)
    f_archive_sha256 = CharField(max_length=100, null=True)
    f_archive_from_ip = CharField(max_length=100, null=True)

    class Meta:
        db_table = 't_pipeline_component_meta'
        # peewee (fields, unique) form: True marks this composite index as unique.
        indexes = (
            (('f_model_id', 'f_model_version', 'f_role', 'f_party_id', 'f_component_name'), True),
        )