config_adapter.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. from copy import deepcopy
  19. from fate_flow.entity import RunParameters
  20. class JobRuntimeConfigAdapter(object):
  21. def __init__(self, job_runtime_conf):
  22. job_runtime_conf = deepcopy(job_runtime_conf)
  23. if 'job_parameters' not in job_runtime_conf:
  24. job_runtime_conf['job_parameters'] = {}
  25. if 'common' not in job_runtime_conf['job_parameters']:
  26. job_runtime_conf['job_parameters']['common'] = {}
  27. self.job_runtime_conf = job_runtime_conf
  28. def get_common_parameters(self):
  29. if int(self.job_runtime_conf.get('dsl_version', 1)) == 2:
  30. job_parameters = RunParameters(**self.job_runtime_conf.get("job_parameters", {}).get("common", {}))
  31. self.job_runtime_conf['job_parameters']['common'] = job_parameters.to_dict()
  32. else:
  33. if "processors_per_node" in self.job_runtime_conf['job_parameters']:
  34. self.job_runtime_conf['job_parameters']["eggroll_run"] = \
  35. {"eggroll.session.processors.per.node": self.job_runtime_conf['job_parameters']["processors_per_node"]}
  36. job_parameters = RunParameters(**self.job_runtime_conf['job_parameters'])
  37. self.job_runtime_conf['job_parameters'] = job_parameters.to_dict()
  38. return job_parameters
  39. def update_common_parameters(self, common_parameters: RunParameters):
  40. if int(self.job_runtime_conf.get("dsl_version", 1)) == 2:
  41. self.job_runtime_conf["job_parameters"]["common"] = common_parameters.to_dict()
  42. else:
  43. self.job_runtime_conf["job_parameters"] = common_parameters.to_dict()
  44. return self.job_runtime_conf
  45. def get_job_parameters_dict(self, job_parameters: RunParameters = None):
  46. if job_parameters:
  47. if int(self.job_runtime_conf.get('dsl_version', 1)) == 2:
  48. self.job_runtime_conf['job_parameters']['common'] = job_parameters.to_dict()
  49. else:
  50. self.job_runtime_conf['job_parameters'] = job_parameters.to_dict()
  51. return self.job_runtime_conf['job_parameters']
  52. def check_removed_parameter(self):
  53. check_list = []
  54. if self.check_backend():
  55. check_list.append("backend")
  56. if self.check_work_mode():
  57. check_list.append("work_mode")
  58. return ','.join(check_list)
  59. def check_backend(self):
  60. if int(self.job_runtime_conf.get('dsl_version', 1)) == 2:
  61. backend = self.job_runtime_conf['job_parameters'].get('common', {}).get('backend')
  62. else:
  63. backend = self.job_runtime_conf['job_parameters'].get('backend')
  64. return backend is not None
  65. def check_work_mode(self):
  66. if int(self.job_runtime_conf.get('dsl_version', 1)) == 2:
  67. work_mode = self.job_runtime_conf['job_parameters'].get('common', {}).get('work_mode')
  68. else:
  69. work_mode = self.job_runtime_conf['job_parameters'].get('work_mode')
  70. return work_mode is not None
  71. def get_job_type(self):
  72. if int(self.job_runtime_conf.get('dsl_version', 1)) == 2:
  73. job_type = self.job_runtime_conf['job_parameters'].get('common', {}).get('job_type')
  74. if not job_type:
  75. job_type = self.job_runtime_conf['job_parameters'].get('job_type', 'train')
  76. else:
  77. job_type = self.job_runtime_conf['job_parameters'].get('job_type', 'train')
  78. return job_type
  79. def update_model_id_version(self, model_id=None, model_version=None):
  80. if int(self.job_runtime_conf.get('dsl_version', 1)) == 2:
  81. if model_id:
  82. self.job_runtime_conf['job_parameters']['common']['model_id'] = model_id
  83. if model_version:
  84. self.job_runtime_conf['job_parameters']['common']['model_version'] = model_version
  85. else:
  86. if model_id:
  87. self.job_runtime_conf['job_parameters']['model_id'] = model_id
  88. if model_version:
  89. self.job_runtime_conf['job_parameters']['model_version'] = model_version
  90. return self.job_runtime_conf