migrate_model.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_arch.common.base_utils import json_dumps, json_loads

from fate_flow.db.db_models import DB, MachineLearningModelInfo as MLModel, PipelineComponentMeta
from fate_flow.model.sync_model import SyncModel
from fate_flow.pipelined_model import pipelined_model
from fate_flow.scheduler.cluster_scheduler import ClusterScheduler
from fate_flow.settings import ENABLE_MODEL_STORE, stat_logger
from fate_flow.utils.base_utils import compare_version
from fate_flow.utils.config_adapter import JobRuntimeConfigAdapter
from fate_flow.utils.job_utils import PIPELINE_COMPONENT_NAME
from fate_flow.utils.model_utils import (
    gather_model_info_data, gen_model_id,
    gen_party_model_id, save_model_info,
)
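

# compare_roles checks that the role section of a migration request matches the role
# section of the model's runtime configuration: the same role keys, list-typed values of
# equal length, and (for the returned flag) the same party ids under each role.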
def compare_roles(request_conf_roles: dict, run_time_conf_roles: dict):
    if request_conf_roles.keys() == run_time_conf_roles.keys():
        verify_format = True
        verify_equality = True

        for key in request_conf_roles.keys():
            verify_format = (
                verify_format and
                len(request_conf_roles[key]) == len(run_time_conf_roles[key]) and
                isinstance(request_conf_roles[key], list)
            )
            request_conf_roles_set = set(str(item) for item in request_conf_roles[key])
            run_time_conf_roles_set = set(str(item) for item in run_time_conf_roles[key])
            verify_equality = verify_equality and (request_conf_roles_set == run_time_conf_roles_set)

        if not verify_format:
            raise Exception("The structure of roles data of local configuration is different from "
                            "model runtime configuration's. Migration aborting.")
        else:
            return verify_equality

    raise Exception("The structure of roles data of local configuration is different from "
                    "model runtime configuration's. Migration aborting.")
def migration(config_data: dict):
    model_id = config_data['model_id']
    model_version = config_data['model_version']
    local_role = config_data['local']['role']
    local_party_id = config_data['local']['party_id']
    new_party_id = config_data["local"]["migrate_party_id"]
    new_model_id = gen_model_id(config_data["migrate_role"])
    unify_model_version = config_data['unify_model_version']

    try:
        if ENABLE_MODEL_STORE:
            sync_model = SyncModel(
                role=local_role, party_id=local_party_id,
                model_id=model_id, model_version=model_version,
            )
            if sync_model.remote_exists():
                sync_model.download(True)
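
        # Locate the source model in the local cache; migration cannot proceed without it.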
        party_model_id = gen_party_model_id(
            model_id=model_id,
            role=local_role,
            party_id=local_party_id,
        )
        source_model = pipelined_model.PipelinedModel(party_model_id, model_version)
        if not source_model.exists():
            raise FileNotFoundError(f"Cannot find the local cache of model {model_id} {model_version}.")

        with DB.connection_context():
            if MLModel.get_or_none(
                MLModel.f_role == local_role,
                MLModel.f_party_id == new_party_id,
                MLModel.f_model_id == new_model_id,
                MLModel.f_model_version == unify_model_version,
            ):
                raise FileExistsError(
                    f"Unify model version {unify_model_version} is already in use in the database. "
                    "Please choose another unify model version and try again."
                )
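
        # Obtain the migration tool for this model and create the target PipelinedModel
        # under the new party model id and the unified model version.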
        migrate_tool = source_model.get_model_migrate_tool()
        migrate_model = pipelined_model.PipelinedModel(
            gen_party_model_id(
                model_id=new_model_id,
                role=local_role,
                party_id=new_party_id,
            ),
            unify_model_version,
        )
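
        # Migrate every component model except the pipeline meta component, replacing the old
        # guest/host/arbiter party ids with the new ones, and save it into the target model.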
        query = source_model.pipelined_component.get_define_meta_from_db(
            PipelineComponentMeta.f_component_name != PIPELINE_COMPONENT_NAME,
        )
        for row in query:
            buffer_obj = source_model.read_component_model(row.f_component_name, row.f_model_alias)

            modified_buffer = migrate_tool.model_migration(
                model_contents=buffer_obj,
                module_name=row.f_component_module_name,
                old_guest_list=config_data['role']['guest'],
                new_guest_list=config_data['migrate_role']['guest'],
                old_host_list=config_data['role']['host'],
                new_host_list=config_data['migrate_role']['host'],
                old_arbiter_list=config_data.get('role', {}).get('arbiter', None),
                new_arbiter_list=config_data.get('migrate_role', {}).get('arbiter', None),
            )

            migrate_model.save_component_model(
                row.f_component_name, row.f_component_module_name,
                row.f_model_alias, modified_buffer, row.f_run_parameters,
            )
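
        # Rewrite the top-level pipeline metadata: new model id and version, migrated roles
        # and initiator, and refreshed training / on-party runtime configurations.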
        pipeline_model = source_model.read_pipeline_model()

        pipeline_model.model_id = new_model_id
        pipeline_model.model_version = unify_model_version
        pipeline_model.roles = json_dumps(config_data['migrate_role'], byte=True)

        train_runtime_conf = json_loads(pipeline_model.train_runtime_conf)
        train_runtime_conf['role'] = config_data['migrate_role']
        train_runtime_conf['initiator'] = config_data['migrate_initiator']
        train_runtime_conf = JobRuntimeConfigAdapter(
            train_runtime_conf,
        ).update_model_id_version(
            model_id=new_model_id,
            model_version=unify_model_version,
        )
        pipeline_model.train_runtime_conf = json_dumps(train_runtime_conf, byte=True)

        if compare_version(pipeline_model.fate_version, '1.5.0') == 'gt':
            pipeline_model.initiator_role = config_data["migrate_initiator"]['role']
            pipeline_model.initiator_party_id = config_data["migrate_initiator"]['party_id']

            runtime_conf_on_party = json_loads(pipeline_model.runtime_conf_on_party)
            runtime_conf_on_party['role'] = config_data['migrate_role']
            runtime_conf_on_party['initiator'] = config_data['migrate_initiator']
            runtime_conf_on_party['job_parameters']['model_id'] = new_model_id
            runtime_conf_on_party['job_parameters']['model_version'] = unify_model_version
            pipeline_model.runtime_conf_on_party = json_dumps(runtime_conf_on_party, byte=True)
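
        # Persist the migrated pipeline model, register its info, and ask every node in the
        # cluster to package the new model archive.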
        migrate_model.save_pipeline_model(pipeline_model)

        migrate_model_info = gather_model_info_data(migrate_model)
        save_model_info(migrate_model_info)

        ClusterScheduler.cluster_command('/model/archive/packaging', {
            'party_model_id': migrate_model.party_model_id,
            'model_version': migrate_model.model_version,
        })

        return (0, (
            "Model migrated successfully. The model configuration has been modified automatically. "
            f"New model id is: {migrate_model._model_id}, model version is: {migrate_model.model_version}. "
            f"Model files can be found at '{migrate_model.archive_model_file_path}'."
        ), {
            "model_id": migrate_model.party_model_id,
            "model_version": migrate_model.model_version,
            "path": migrate_model.archive_model_file_path,
        })
    except Exception as e:
        stat_logger.exception(e)
        return 100, str(e), {}
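

# Illustrative only: a minimal sketch of the request body that migration() expects, based on
# the keys read above. The party ids, model id and versions are hypothetical placeholders.
#
#     config_data = {
#         "model_id": "arbiter-10000#guest-9999#host-10000#model",
#         "model_version": "202201011200000000000",
#         "unify_model_version": "202202021200000000000",
#         "local": {"role": "guest", "party_id": 9999, "migrate_party_id": 8888},
#         "role": {"guest": [9999], "host": [10000], "arbiter": [10000]},
#         "migrate_role": {"guest": [8888], "host": [7777], "arbiter": [7777]},
#         "migrate_initiator": {"role": "guest", "party_id": 8888},
#     }
#
#     # compare_roles returns True only when both sides list the same party ids per role,
#     # so the old and new role sections above would compare as False.
#     compare_roles(config_data["role"], config_data["migrate_role"])
#
#     retcode, retmsg, data = migration(config_data)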