123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from fate_arch.common.base_utils import json_dumps, json_loads
- from fate_flow.db.db_models import DB, MachineLearningModelInfo as MLModel, PipelineComponentMeta
- from fate_flow.model.sync_model import SyncModel
- from fate_flow.pipelined_model import pipelined_model
- from fate_flow.scheduler.cluster_scheduler import ClusterScheduler
- from fate_flow.settings import ENABLE_MODEL_STORE, stat_logger
- from fate_flow.utils.base_utils import compare_version
- from fate_flow.utils.config_adapter import JobRuntimeConfigAdapter
- from fate_flow.utils.job_utils import PIPELINE_COMPONENT_NAME
- from fate_flow.utils.model_utils import (
- gather_model_info_data, gen_model_id,
- gen_party_model_id, save_model_info,
- )
def compare_roles(request_conf_roles: dict, run_time_conf_roles: dict):
    """Tell whether two role configurations name the same parties.

    Both arguments map a role name to a collection of party ids. Party ids are
    compared by their ``str()`` form and without regard to order or
    duplicates, so ``[9999]`` matches ``["9999"]``.

    Args:
        request_conf_roles: roles from the local (request) configuration;
            every value must be a ``list``.
        run_time_conf_roles: roles from the model runtime configuration.

    Returns:
        ``True`` when every role holds the same set of party ids in both
        configurations, ``False`` otherwise.

    Raises:
        Exception: when the two configurations differ in structure, i.e. the
            role names differ, a requested value is not a list, or the
            per-role collections differ in length (or have no length at all).
    """
    structure_error = (
        "The structure of roles data of local configuration is different from "
        "model runtime configuration's. Migration aborting."
    )

    if request_conf_roles.keys() != run_time_conf_roles.keys():
        raise Exception(structure_error)

    roles_match = True
    for role, requested in request_conf_roles.items():
        recorded = run_time_conf_roles[role]

        # Check the type before touching len(): a scalar party id or None must
        # be reported as a structure mismatch, not leak a raw TypeError.
        if not isinstance(requested, list):
            raise Exception(structure_error)

        try:
            same_length = len(requested) == len(recorded)
        except TypeError:
            # The runtime side is not a sized collection.
            same_length = False
        if not same_length:
            raise Exception(structure_error)

        requested_ids = {str(item) for item in requested}
        recorded_ids = {str(item) for item in recorded}
        roles_match = roles_match and requested_ids == recorded_ids

    return roles_match
def migration(config_data: dict):
    """Re-issue a locally cached model under a new model id / party layout.

    Reads the source model identified by ``model_id`` / ``model_version`` for
    the local role and party, passes every non-pipeline component through the
    model's migrate tool, rewrites the pipeline metadata to the migrated
    roles / initiator, stores the result as a new model, records it via
    ``save_model_info`` and issues the ``/model/archive/packaging`` cluster
    command.

    Args:
        config_data: migration request. Keys read here: ``model_id``,
            ``model_version``, ``unify_model_version``, ``local`` (with
            ``role``, ``party_id``, ``migrate_party_id``), ``role``,
            ``migrate_role`` and ``migrate_initiator``.

    Returns:
        ``(0, message, info)`` on success, where ``info`` carries
        ``model_id``, ``model_version`` and ``path`` of the migrated model;
        ``(100, error_text, {})`` when anything inside the guarded section
        fails.

    Raises:
        KeyError: when one of the keys read before the guarded section
            is missing; those lookups are not covered by the ``try``.
    """
    source_model_id = config_data['model_id']
    source_version = config_data['model_version']
    local_conf = config_data['local']
    role_here = local_conf['role']
    party_here = local_conf['party_id']
    target_party = local_conf["migrate_party_id"]
    target_model_id = gen_model_id(config_data["migrate_role"])
    target_version = config_data['unify_model_version']

    try:
        # Pull the model from the model store first when one is configured
        # and actually holds this model.
        if ENABLE_MODEL_STORE:
            syncer = SyncModel(
                role=role_here,
                party_id=party_here,
                model_id=source_model_id,
                model_version=source_version,
            )
            if sync_needed := syncer.remote_exists():
                syncer.download(True)

        source_party_model_id = gen_party_model_id(
            model_id=source_model_id,
            role=role_here,
            party_id=party_here,
        )
        origin = pipelined_model.PipelinedModel(source_party_model_id, source_version)
        if not origin.exists():
            raise FileNotFoundError(f"Can not found {source_model_id} {source_version} model local cache.")

        # Refuse to overwrite a model already registered under the target
        # role / party / model id / version.
        with DB.connection_context():
            occupied = MLModel.get_or_none(
                MLModel.f_role == role_here,
                MLModel.f_party_id == target_party,
                MLModel.f_model_id == target_model_id,
                MLModel.f_model_version == target_version,
            )
            if occupied:
                raise FileExistsError(
                    f"Unify model version {target_version} has been occupied in database. "
                    "Please choose another unify model version and try again."
                )

        migrator = origin.get_model_migrate_tool()
        target_party_model_id = gen_party_model_id(
            model_id=target_model_id,
            role=role_here,
            party_id=target_party,
        )
        migrated = pipelined_model.PipelinedModel(target_party_model_id, target_version)

        # Every component except the pipeline pseudo-component is converted
        # one by one and written into the migrated model.
        component_rows = origin.pipelined_component.get_define_meta_from_db(
            PipelineComponentMeta.f_component_name != PIPELINE_COMPONENT_NAME,
        )
        for meta in component_rows:
            original_buffer = origin.read_component_model(meta.f_component_name, meta.f_model_alias)

            converted_buffer = migrator.model_migration(
                model_contents=original_buffer,
                module_name=meta.f_component_module_name,
                old_guest_list=config_data['role']['guest'],
                new_guest_list=config_data['migrate_role']['guest'],
                old_host_list=config_data['role']['host'],
                new_host_list=config_data['migrate_role']['host'],
                # Arbiter is optional on both sides.
                old_arbiter_list=config_data.get('role', {}).get('arbiter'),
                new_arbiter_list=config_data.get('migrate_role', {}).get('arbiter'),
            )

            migrated.save_component_model(
                meta.f_component_name,
                meta.f_component_module_name,
                meta.f_model_alias,
                converted_buffer,
                meta.f_run_parameters,
            )

        # Rewrite the pipeline metadata so it describes the migrated model.
        pipeline = origin.read_pipeline_model()
        pipeline.model_id = target_model_id
        pipeline.model_version = target_version
        pipeline.roles = json_dumps(config_data['migrate_role'], byte=True)

        train_conf = json_loads(pipeline.train_runtime_conf)
        train_conf['role'] = config_data['migrate_role']
        train_conf['initiator'] = config_data['migrate_initiator']
        conf_adapter = JobRuntimeConfigAdapter(train_conf)
        train_conf = conf_adapter.update_model_id_version(
            model_id=target_model_id,
            model_version=target_version,
        )
        pipeline.train_runtime_conf = json_dumps(train_conf, byte=True)

        # Only models produced by a FATE version newer than 1.5.0 get the
        # initiator fields and the per-party runtime conf rewritten.
        newer_than_1_5_0 = compare_version(pipeline.fate_version, '1.5.0') == 'gt'
        if newer_than_1_5_0:
            pipeline.initiator_role = config_data["migrate_initiator"]['role']
            pipeline.initiator_party_id = config_data["migrate_initiator"]['party_id']

            party_conf = json_loads(pipeline.runtime_conf_on_party)
            party_conf['role'] = config_data['migrate_role']
            party_conf['initiator'] = config_data['migrate_initiator']
            party_conf['job_parameters']['model_id'] = target_model_id
            party_conf['job_parameters']['model_version'] = target_version
            pipeline.runtime_conf_on_party = json_dumps(party_conf, byte=True)

        migrated.save_pipeline_model(pipeline)

        save_model_info(gather_model_info_data(migrated))

        packaging_request = {
            'party_model_id': migrated.party_model_id,
            'model_version': migrated.model_version,
        }
        ClusterScheduler.cluster_command('/model/archive/packaging', packaging_request)

        success_message = (
            "Migrating model successfully. The configuration of model has been modified automatically. "
            f"New model id is: {migrated._model_id}, model version is: {migrated.model_version}. "
            f"Model files can be found at '{migrated.archive_model_file_path}'."
        )
        result_info = {
            "model_id": migrated.party_model_id,
            "model_version": migrated.model_version,
            "path": migrated.archive_model_file_path,
        }
        return 0, success_message, result_info
    except Exception as e:
        # Every failure in the guarded section is logged and reported through
        # the return code instead of propagating.
        stat_logger.exception(e)
        return 100, str(e), {}
|