# deploy_model.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import shutil

from fate_arch.common.base_utils import json_dumps, json_loads

from fate_flow.db.db_models import PipelineComponentMeta
from fate_flow.model.checkpoint import CheckpointManager
from fate_flow.model.sync_model import SyncComponent, SyncModel
from fate_flow.operation.job_saver import JobSaver
from fate_flow.pipelined_model.pipelined_model import PipelinedModel
from fate_flow.settings import ENABLE_MODEL_STORE, stat_logger
from fate_flow.utils.base_utils import compare_version
from fate_flow.utils.config_adapter import JobRuntimeConfigAdapter
from fate_flow.utils.model_utils import (
    check_before_deploy, gather_model_info_data,
    gen_party_model_id, save_model_info,
)
from fate_flow.utils.schedule_utils import get_dsl_parser_by_version
def deploy(config_data):
    """Deploy a child (inference) model from an existing trained parent model.

    Reads the parent model identified by ``model_id``/``model_version`` from
    the local pipelined-model cache, derives an inference DSL (either supplied
    by the caller or generated from the stored training DSL), re-stamps the
    runtime configuration with ``child_model_version``, copies the selected
    components' variable data and checkpoints over to the child, and persists
    the resulting child model.

    :param config_data: dict with keys ``model_id``, ``model_version``,
        ``local`` (``{'role', 'party_id'}``), ``child_model_version``, and
        optionally ``predict_dsl`` (or legacy ``dsl``), ``cpn_list``,
        ``components_checkpoint``.
    :return: ``(retcode, retmsg)`` tuple — ``(0, msg)`` on success (possibly
        including a warning suffix), ``(100, msg)`` on any failure; the
        exception itself is logged via ``stat_logger``.
    """
    model_id = config_data['model_id']
    model_version = config_data['model_version']
    local_role = config_data['local']['role']
    local_party_id = config_data['local']['party_id']
    child_model_version = config_data['child_model_version']
    # Optional per-component checkpoint selection:
    # {component_name: {'step_index': ..., 'step_name': ...}}
    components_checkpoint = config_data.get('components_checkpoint', {})
    warning_msg = ""

    try:
        # If the model store is enabled and a remote copy of the parent model
        # exists, pull it down first so the local cache is current.
        if ENABLE_MODEL_STORE:
            sync_model = SyncModel(
                role=local_role, party_id=local_party_id,
                model_id=model_id, model_version=model_version,
            )
            if sync_model.remote_exists():
                sync_model.download(True)

        party_model_id = gen_party_model_id(
            model_id=model_id,
            role=local_role,
            party_id=local_party_id,
        )
        source_model = PipelinedModel(party_model_id, model_version)  # parent
        deploy_model = PipelinedModel(party_model_id, child_model_version)  # child
        if not source_model.exists():
            raise FileNotFoundError(f'Can not found {model_id} {model_version} model local cache.')
        if not check_before_deploy(source_model):
            raise Exception('Child model could not be deployed.')

        # Start from the parent's pipeline metadata and re-stamp the version.
        pipeline_model = source_model.read_pipeline_model()
        pipeline_model.model_version = child_model_version

        train_runtime_conf = json_loads(pipeline_model.train_runtime_conf)
        # Models without an explicit 'dsl_version' are treated as DSL v1.
        dsl_version = int(train_runtime_conf.get('dsl_version', 1))
        parser = get_dsl_parser_by_version(dsl_version)

        # The caller may supply a ready-made inference DSL via 'predict_dsl'
        # (legacy fallback key: 'dsl'); otherwise one is derived below.
        inference_dsl = config_data.get('predict_dsl', config_data.get('dsl'))
        if inference_dsl is not None:
            if dsl_version == 1:
                raise KeyError("'predict_dsl' is not supported in DSL v1")
            if 'cpn_list' in config_data:
                # A user-supplied DSL and a component list are mutually exclusive.
                raise KeyError("'cpn_list' should not be set when 'predict_dsl' is set")
            if not isinstance(inference_dsl, dict):
                # Accept a JSON string as well as a parsed dict.
                inference_dsl = json_loads(inference_dsl)
        else:
            if dsl_version == 1:
                if 'cpn_list' in config_data:
                    raise KeyError("'cpn_list' is not supported in DSL v1")
                # DSL v1 models already carry an inference DSL; convert it to
                # v2. The conversion may also yield a non-fatal warning message
                # that is appended to the success response.
                inference_dsl, warning_msg = parser.convert_dsl_v1_to_v2(
                    json_loads(pipeline_model.inference_dsl),
                )
            else:
                # DSL v2: build the inference DSL from the training DSL,
                # restricted to 'cpn_list' (default: all trained components).
                train_dsl = json_loads(pipeline_model.train_dsl)
                inference_dsl = parser.deploy_component(
                    config_data.get(
                        'cpn_list',
                        list(train_dsl.get('components', {}).keys()),
                    ),
                    train_dsl,
                )

        # Components actually present in the final inference DSL.
        cpn_list = list(inference_dsl.get('components', {}).keys())

        if dsl_version == 1:
            # The v1 runtime conf must be converted to v2 as well; the
            # conversion needs each component's provider name and version.
            from fate_flow.db.component_registry import ComponentRegistry
            job_providers = parser.get_job_providers(
                dsl=inference_dsl,
                provider_detail=ComponentRegistry.REGISTRY,
            )
            train_runtime_conf = parser.convert_conf_v1_to_v2(
                train_runtime_conf,
                {
                    cpn: parser.parse_component_role_parameters(
                        component=cpn,
                        dsl=inference_dsl,
                        runtime_conf=train_runtime_conf,
                        provider_detail=ComponentRegistry.REGISTRY,
                        provider_name=job_providers[cpn]['provider']['name'],
                        provider_version=job_providers[cpn]['provider']['version'],
                    ) for cpn in cpn_list
                }
            )
            # Everything is v2 from here on; switch to the default (v2) parser.
            parser = get_dsl_parser_by_version()

        parser.verify_dsl(inference_dsl, 'predict')
        inference_dsl = JobSaver.fill_job_inference_dsl(
            job_id=model_version, role=local_role, party_id=local_party_id,
            dsl_parser=parser, origin_inference_dsl=inference_dsl,
        )
        pipeline_model.inference_dsl = json_dumps(inference_dsl, byte=True)

        # Re-stamp the training runtime conf with the child model version.
        train_runtime_conf = JobRuntimeConfigAdapter(
            train_runtime_conf,
        ).update_model_id_version(
            model_version=child_model_version,
        )
        pipeline_model.train_runtime_conf = json_dumps(train_runtime_conf, byte=True)

        # runtime_conf_on_party is only present for models produced by FATE
        # versions newer than 1.5.0; re-stamp it too when available.
        if compare_version(pipeline_model.fate_version, '1.5.0') == 'gt':
            runtime_conf_on_party = json_loads(pipeline_model.runtime_conf_on_party)
            runtime_conf_on_party['job_parameters']['model_version'] = child_model_version
            pipeline_model.runtime_conf_on_party = json_dumps(runtime_conf_on_party, byte=True)

        # Mark the new model as a child and record its lineage.
        pipeline_model.parent = False
        pipeline_model.parent_info = json_dumps({
            'parent_model_id': model_id,
            'parent_model_version': model_version,
        }, byte=True)

        # Copy the deployed components' variable data from parent to child and
        # replicate their define-meta DB rows under the child version. The
        # archive fields are reset because the child has not been archived yet.
        query_args = (
            PipelineComponentMeta.f_component_name.in_(cpn_list),
        )
        query = source_model.pipelined_component.get_define_meta_from_db(*query_args)
        for row in query:
            shutil.copytree(
                source_model.pipelined_component.variables_data_path / row.f_component_name,
                deploy_model.pipelined_component.variables_data_path / row.f_component_name,
            )
        source_model.pipelined_component.replicate_define_meta({
            'f_model_version': child_model_version,
            'f_archive_sha256': None,
            'f_archive_from_ip': None,
        }, query_args)

        # Mirror each deployed component into remote model storage under the
        # child version by copying from the parent version's archive.
        if ENABLE_MODEL_STORE:
            for row in query:
                sync_component = SyncComponent(
                    role=local_role, party_id=local_party_id,
                    model_id=model_id, model_version=child_model_version,
                    component_name=row.f_component_name,
                )
                sync_component.copy(model_version, row.f_archive_sha256)

        deploy_model.save_pipeline_model(pipeline_model)

        # Deploy checkpoints per component: an explicit step_index takes
        # precedence over step_name; components with neither are skipped.
        for row in query:
            step_index = components_checkpoint.get(row.f_component_name, {}).get('step_index')
            step_name = components_checkpoint.get(row.f_component_name, {}).get('step_name')
            if step_index is not None:
                step_index = int(step_index)
                step_name = None
            elif step_name is None:
                continue

            checkpoint_manager = CheckpointManager(
                role=local_role, party_id=local_party_id,
                model_id=model_id, model_version=model_version,
                component_name=row.f_component_name,
            )
            checkpoint_manager.load_checkpoints_from_disk()
            if checkpoint_manager.latest_checkpoint is not None:
                checkpoint_manager.deploy(
                    child_model_version,
                    row.f_model_alias,
                    step_index,
                    step_name,
                )

        # Persist the child model's info record.
        deploy_model_info = gather_model_info_data(deploy_model)
        save_model_info(deploy_model_info)
    except Exception as e:
        stat_logger.exception(e)
        return 100, (
            f'deploy model of role {local_role} {local_party_id} failed, '
            f'details: {repr(e)}'
        )
    else:
        return 0, (
            f'deploy model of role {local_role} {local_party_id} success'
            + ('' if not warning_msg else f', warning: {warning_msg}')
        )