123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- from copy import deepcopy
- from fate_flow.entity import RunParameters
class JobRuntimeConfigAdapter(object):
    """Uniform accessor for a job runtime conf in either DSL v1 or DSL v2 layout.

    The layout is chosen by the conf's ``dsl_version`` key (``int()``-converted,
    treated as 1 when absent):

    * DSL v2 keeps job-level parameters in ``conf["job_parameters"]["common"]``.
    * DSL v1 keeps them directly in ``conf["job_parameters"]``.

    The adapter works on a deep copy of the conf it is given, so the caller's
    object is never mutated; methods that change parameters mutate
    ``self.job_runtime_conf`` in place.
    """

    def __init__(self, job_runtime_conf):
        """Store a deep copy of *job_runtime_conf*, creating empty sections as needed.

        ``job_parameters`` and ``job_parameters["common"]`` are added as empty
        dicts when missing. ``common`` is added regardless of DSL version, so a
        v1 conf also gains an (empty) ``common`` key.
        """
        job_runtime_conf = deepcopy(job_runtime_conf)
        if "job_parameters" not in job_runtime_conf:
            job_runtime_conf["job_parameters"] = {}
        if "common" not in job_runtime_conf["job_parameters"]:
            job_runtime_conf["job_parameters"]["common"] = {}
        self.job_runtime_conf = job_runtime_conf

    # -- internal helpers -------------------------------------------------

    def _is_dsl_v2(self):
        """Return True when the conf declares ``dsl_version`` 2 (absent means 1)."""
        return int(self.job_runtime_conf.get("dsl_version", 1)) == 2

    def _parameter_section(self, create=False):
        """Return the dict that holds job-level parameters for this DSL version.

        For DSL v1 this is ``job_parameters`` itself. For DSL v2 it is
        ``job_parameters["common"]``; if that key is missing an empty dict is
        returned without being stored, unless *create* is true, in which case
        an empty dict is inserted into the conf and returned so writes stick.
        """
        job_parameters = self.job_runtime_conf["job_parameters"]
        if not self._is_dsl_v2():
            return job_parameters
        if create:
            return job_parameters.setdefault("common", {})
        return job_parameters.get("common", {})

    def _set_parameter_section(self, parameters):
        """Replace the job-level parameter dict for this DSL version with *parameters*."""
        if self._is_dsl_v2():
            self.job_runtime_conf["job_parameters"]["common"] = parameters
        else:
            self.job_runtime_conf["job_parameters"] = parameters

    @staticmethod
    def _apply_processors_per_node(parameters):
        """Translate the v1 ``processors_per_node`` shortcut into ``eggroll_run``.

        Does nothing when ``processors_per_node`` is absent. Otherwise the value
        is written to ``eggroll_run["eggroll.session.processors.per.node"]``,
        merging into an existing ``eggroll_run`` dict so other user-supplied
        eggroll settings are kept. Mutates *parameters* in place.
        """
        if "processors_per_node" not in parameters:
            return
        existing = parameters.get("eggroll_run")
        # A non-dict value cannot be merged into; replace it as before.
        eggroll_run = dict(existing) if isinstance(existing, dict) else {}
        eggroll_run["eggroll.session.processors.per.node"] = parameters["processors_per_node"]
        parameters["eggroll_run"] = eggroll_run

    # -- public API -------------------------------------------------------

    def get_common_parameters(self):
        """Build a ``RunParameters`` from the conf and write it back normalized.

        For DSL v1 the ``processors_per_node`` shortcut is first folded into
        ``eggroll_run``. The conf's parameter section is then replaced by
        ``RunParameters.to_dict()`` of the constructed object.

        Returns:
            The constructed ``RunParameters`` instance.
        """
        if self._is_dsl_v2():
            job_parameters = RunParameters(**self._parameter_section())
        else:
            conf_parameters = self.job_runtime_conf["job_parameters"]
            self._apply_processors_per_node(conf_parameters)
            job_parameters = RunParameters(**conf_parameters)
        self._set_parameter_section(job_parameters.to_dict())
        return job_parameters

    def update_common_parameters(self, common_parameters: "RunParameters"):
        """Overwrite the conf's parameter section with ``common_parameters.to_dict()``.

        Returns:
            The (mutated) ``self.job_runtime_conf``.
        """
        self._set_parameter_section(common_parameters.to_dict())
        return self.job_runtime_conf

    def get_job_parameters_dict(self, job_parameters: "RunParameters | None" = None):
        """Return ``conf["job_parameters"]``, optionally updating it first.

        Args:
            job_parameters: When given, its ``to_dict()`` replaces the
                parameter section before the dict is returned.

        Returns:
            The live ``job_parameters`` dict of ``self.job_runtime_conf``.
        """
        if job_parameters is not None:
            self.update_common_parameters(job_parameters)
        return self.job_runtime_conf["job_parameters"]

    def check_removed_parameter(self):
        """Return a comma-separated list of removed parameters present in the conf.

        Currently checks ``backend`` and ``work_mode``; returns ``""`` if none
        are set.
        """
        present = []
        if self.check_backend():
            present.append("backend")
        if self.check_work_mode():
            present.append("work_mode")
        return ",".join(present)

    def check_backend(self):
        """Return True if a non-None ``backend`` is set in the parameter section."""
        return self._parameter_section().get("backend") is not None

    def check_work_mode(self):
        """Return True if a non-None ``work_mode`` is set in the parameter section."""
        return self._parameter_section().get("work_mode") is not None

    def get_job_type(self):
        """Return the job type, defaulting to ``"train"``.

        For DSL v2, ``common["job_type"]`` is used when truthy; otherwise (and
        always for DSL v1) ``job_parameters["job_type"]`` is used.
        """
        job_type = None
        if self._is_dsl_v2():
            job_type = self._parameter_section().get("job_type")
        if not job_type:
            job_type = self.job_runtime_conf["job_parameters"].get("job_type", "train")
        return job_type

    def update_model_id_version(self, model_id=None, model_version=None):
        """Set ``model_id`` / ``model_version`` in the parameter section.

        Falsy arguments are skipped, leaving any existing value untouched.

        Returns:
            The (mutated) ``self.job_runtime_conf``.
        """
        section = self._parameter_section(create=True)
        if model_id:
            section["model_id"] = model_id
        if model_version:
            section["model_version"] = model_version
        return self.job_runtime_conf
|