model_utils.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import glob

from fate_arch.common.base_utils import json_loads

from fate_flow.db.db_models import DB, MachineLearningModelInfo as MLModel
from fate_flow.model.sync_model import SyncModel
from fate_flow.pipelined_model.pipelined_model import PipelinedModel
from fate_flow.scheduler.cluster_scheduler import ClusterScheduler
from fate_flow.settings import ENABLE_MODEL_STORE, stat_logger
from fate_flow.utils.base_utils import compare_version, get_fate_flow_directory


def all_party_key(all_party):
    """
    Join all party as party key
    :param all_party:
        "role": {
            "guest": [9999],
            "host": [10000],
            "arbiter": [10000]
        }
    :return:
    """
    if not all_party:
        all_party_key = 'all'
    elif isinstance(all_party, dict):
        sorted_role_name = sorted(all_party.keys())
        all_party_key = '#'.join([
            ('%s-%s' % (
                role_name,
                '_'.join([str(p) for p in sorted(set(all_party[role_name]))]))
            )
            for role_name in sorted_role_name])
    else:
        all_party_key = None
    return all_party_key
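
# Illustrative example (party ids are hypothetical): roles are sorted
# alphabetically and party ids are de-duplicated and sorted within each role,
# so the key is deterministic for a given role dict.
#
#   all_party_key({"guest": [9999], "host": [10000], "arbiter": [10000]})
#   # -> 'arbiter-10000#guest-9999#host-10000'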


def gen_party_model_id(model_id, role, party_id):
    return '#'.join([role, str(party_id), model_id]) if model_id else None


def gen_model_id(all_party):
    return '#'.join([all_party_key(all_party), "model"])
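
# Illustrative composition (parties hypothetical): the federation-wide model id
# appends a "model" suffix to the party key, and the party-local id prefixes it
# with the local role and party id.
#
#   gen_model_id({"guest": [9999], "host": [10000]})
#   # -> 'guest-9999#host-10000#model'
#   gen_party_model_id('guest-9999#host-10000#model', 'guest', 9999)
#   # -> 'guest#9999#guest-9999#host-10000#model'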


@DB.connection_context()
def query_model_info_from_db(query_filters=None, **kwargs):
    conditions = []
    filters = []

    # Any keyword argument matching an MLModel column (with the f_ prefix)
    # becomes an equality condition.
    for k, v in kwargs.items():
        k = f'f_{k}'
        if hasattr(MLModel, k):
            conditions.append(getattr(MLModel, k) == v)

    # query_filters limits the selected columns; guard against the None default.
    for k in query_filters or []:
        k = f'f_{k}'
        if hasattr(MLModel, k):
            filters.append(getattr(MLModel, k))

    models = MLModel.select(*filters)
    if conditions:
        models = models.where(*conditions)
    models = [model.to_dict() for model in models]

    if not models:
        return 100, 'Query model info failed, cannot find model from db.', []
    return 0, 'Query model info from db success.', models


def query_model_info_from_file(model_id='*', model_version='*', role='*', party_id='*', query_filters=None, save_to_db=False, **kwargs):
    # Scan the local model cache; every argument defaults to '*' so unspecified
    # fields match any cached model.
    fp_list = glob.glob(f"{get_fate_flow_directory('model_local_cache')}/{role}#{party_id}#{model_id}/{model_version}")

    models = []

    for fp in fp_list:
        # Recover role, party id and model id from the cache directory name.
        _, party_model_id, model_version = fp.rsplit('/', 2)
        role, party_id, model_id = party_model_id.split('#', 2)

        pipeline_model = PipelinedModel(model_id=party_model_id, model_version=model_version)
        if not pipeline_model.exists():
            continue

        model_info = gather_model_info_data(pipeline_model)
        if save_to_db:
            try:
                save_model_info(model_info)
            except Exception as e:
                stat_logger.exception(e)

        if query_filters:
            model_info = {k: v for k, v in model_info.items() if k in query_filters}
        models.append(model_info)

    if not models:
        return 100, 'Query model info failed, cannot find model from local model files.', []
    return 0, 'Query model info from local model success.', models
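
# Cache layout assumed by the glob above (path components hypothetical):
#
#   <fate_flow>/model_local_cache/guest#9999#guest-9999#host-10000#model/202208010000000000000/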


def gather_model_info_data(model: PipelinedModel):
    pipeline = model.read_pipeline_model()

    # Copy every field of the pipeline protobuf into a DB-style dict,
    # decoding bytes fields as JSON.
    model_info = {}
    for attr, field in pipeline.ListFields():
        if isinstance(field, bytes):
            field = json_loads(field)
        model_info[f'f_{attr.name}'] = field

    model_info['f_job_id'] = model_info['f_model_version']
    model_info['f_role'] = model.role
    model_info['f_party_id'] = model.party_id
    # backward compatibility
    model_info['f_runtime_conf'] = model_info['f_train_runtime_conf']
    model_info['f_size'] = model.calculate_model_file_size()

    # Models trained before FATE 1.5.1 did not store these fields directly,
    # so derive them from the training runtime conf.
    if compare_version(model_info['f_fate_version'], '1.5.1') == 'lt':
        model_info['f_roles'] = model_info.get('f_train_runtime_conf', {}).get('role', {})
        model_info['f_initiator_role'] = model_info.get('f_train_runtime_conf', {}).get('initiator', {}).get('role')
        model_info['f_initiator_party_id'] = model_info.get('f_train_runtime_conf', {}).get('initiator', {}).get('party_id')

    return model_info


def query_model_info(**kwargs):
    file_only = kwargs.pop('file_only', False)
    kwargs['query_filters'] = set(kwargs['query_filters']) if kwargs.get('query_filters') else set()

    if not file_only:
        # Try the DB first; if that misses, fall through to the local cache and
        # write anything found there back to the DB.
        retcode, retmsg, data = query_model_info_from_db(**kwargs)
        if not retcode:
            return retcode, retmsg, data

        kwargs['save_to_db'] = True

    retcode, retmsg, data = query_model_info_from_file(**kwargs)
    if not retcode:
        return retcode, retmsg, data

    return 100, (
        'Query model info failed, cannot find model from db and local model files. '
        'Try use both model id and model version to query model info from local models.'
    ), []
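
# Usage sketch (ids are hypothetical):
#
#   retcode, retmsg, data = query_model_info(
#       model_id='guest-9999#host-10000#model',
#       model_version='202208010000000000000',
#       role='guest', party_id=9999,
#   )
#   # retcode == 0 on a hit, 100 when neither the DB nor the cache has a match.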


def save_model_info(model_info):
    # Normalize keys to the DB column convention (f_ prefix).
    model_info = {k if k.startswith('f_') else f'f_{k}': v for k, v in model_info.items()}

    # Upsert the record: on conflict the new values replace the existing row's columns.
    with DB.connection_context():
        MLModel.insert(**model_info).on_conflict(preserve=(
            'f_update_time',
            'f_update_date',
            *model_info.keys(),
        )).execute()

    if ENABLE_MODEL_STORE:
        sync_model = SyncModel(
            role=model_info['f_role'], party_id=model_info['f_party_id'],
            model_id=model_info['f_model_id'], model_version=model_info['f_model_version'],
        )
        sync_model.upload(True)

    # Broadcast the model service registration to the cluster.
    ClusterScheduler.cluster_command('/model/service/register', {
        'party_model_id': gen_party_model_id(
            model_info['f_model_id'],
            model_info['f_role'],
            model_info['f_party_id'],
        ),
        'model_version': model_info['f_model_version'],
    })
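
# Key normalization example (values hypothetical):
#   {'model_id': 'x', 'role': 'guest'}  ->  {'f_model_id': 'x', 'f_role': 'guest'}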


def check_if_parent_model(pipeline):
    if compare_version(pipeline.fate_version, '1.5.0') == 'gt':
        if pipeline.parent:
            return True
    return False


def check_before_deploy(pipeline_model: PipelinedModel):
    pipeline = pipeline_model.read_pipeline_model()

    if compare_version(pipeline.fate_version, '1.5.0') == 'gt':
        if pipeline.parent:
            return True
    elif compare_version(pipeline.fate_version, '1.5.0') == 'eq':
        return True
    return False


def check_if_deployed(role, party_id, model_id, model_version):
    party_model_id = gen_party_model_id(model_id=model_id, role=role, party_id=party_id)
    pipeline_model = PipelinedModel(model_id=party_model_id, model_version=model_version)
    if not pipeline_model.exists():
        raise FileNotFoundError(f"Model {party_model_id} {model_version} does not exist in the local model cache.")

    pipeline = pipeline_model.read_pipeline_model()
    if compare_version(pipeline.fate_version, '1.5.0') == 'gt':
        train_runtime_conf = json_loads(pipeline.train_runtime_conf)
        if str(train_runtime_conf.get('dsl_version', '1')) != '1':
            # Under DSL v2, a model that still has a parent has not been deployed yet.
            if pipeline.parent:
                return False
    return True


@DB.connection_context()
def models_group_by_party_model_id_and_model_version():
    args = [
        MLModel.f_role,
        MLModel.f_party_id,
        MLModel.f_model_id,
        MLModel.f_model_version,
    ]
    return MLModel.select(*args).group_by(*args)


@DB.connection_context()
def get_job_configuration_from_model(job_id, role, party_id):
    retcode, retmsg, data = query_model_info(
        model_version=job_id, role=role, party_id=party_id,
        query_filters=['train_dsl', 'dsl', 'train_runtime_conf', 'runtime_conf'],
    )
    if not data:
        return {}, {}, {}

    dsl = data[0].get('train_dsl') if data[0].get('train_dsl') else data[0].get('dsl')
    runtime_conf = data[0].get('runtime_conf')
    train_runtime_conf = data[0].get('train_runtime_conf')
    return dsl, runtime_conf, train_runtime_conf
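
# Usage sketch (job id and parties hypothetical): returns the DSL, runtime conf
# and training runtime conf of the model trained by a job, or three empty dicts
# when no matching model info is found.
#
#   dsl, runtime_conf, train_runtime_conf = get_job_configuration_from_model(
#       job_id='202208010000000000000', role='guest', party_id=9999,
#   )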