entity.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. from pipeline.utils.tools import extract_explicit_parameter
  class JobParameters(object):
      """Collect job parameters and render them as a job-config dict.

      Keyword arguments given to the constructor (other than the deprecated
      ``backend`` / ``work_mode``) become instance attributes; public ones are
      emitted under ``"common"`` by :meth:`get_config`. Role- or party-specific
      parameters are set with :meth:`job_param` on instances returned by
      :meth:`get_party_instance` and are emitted under ``"role"``.
      """

      @extract_explicit_parameter
      def __init__(
              self,
              job_type="train",
              computing_engine=None,
              federation_engine=None,
              storage_engine=None,
              engines_address=None,
              federated_mode=None,
              federation_info=None,
              task_parallelism=None,
              federated_status_collect_type=None,
              federated_data_exchange_type=None,
              model_id=None,
              model_version=None,
              dsl_version=None,
              timeout=None,
              eggroll_run=None,
              spark_run=None,
              adaptation_parameters=None,
              **kwargs):
          # The named parameters document the accepted keys; their defaults are
          # never read in this body. Values come only from
          # kwargs["explict_parameters"] (spelling is part of the decorator's
          # contract), which the decorator presumably fills with the keywords
          # the caller actually passed -- confirm against pipeline.utils.tools.
          explicit_parameters = kwargs["explict_parameters"]
          for param_key, param_value in explicit_parameters.items():
              if param_key == "backend":
                  print("Please don't use parameter 'backend' in FATE version >= 1.7.")
              elif param_key == "work_mode":
                  print("Please don't use parameter 'work_mode' in FATE version >= 1.7.")
              else:
                  setattr(self, param_key, param_value)

          # role -> {"party": {party_key: JobParameters}}; party_key is None
          # (role-wide), an int party id, or "id1|id2|..." for a list of ids.
          self.__party_instance = {}
          # Parameters recorded by job_param(); reported per role/party.
          self._job_param = {}

      def get_party_instance(self, role="guest", party_id=None):
          """Return the cached parameter instance for a role and party.

          Args:
              role: one of "guest", "host" or "arbiter".
              party_id: None for settings shared by all parties of the role,
                  a positive int for one party, or a list of positive ints
                  for a group of parties.

          Returns:
              A deep copy of this object (created on first request). Repeated
              calls with the same role/party_id return the same object.

          Raises:
              ValueError: if the role is unknown or a party id is not a
                  positive integer.
          """
          if role not in ["guest", "host", "arbiter"]:
              raise ValueError("Role should be one of guest/host/arbiter")
          if party_id is not None:
              candidate_ids = party_id if isinstance(party_id, list) else [party_id]
              for _id in candidate_ids:
                  if not isinstance(_id, int) or _id <= 0:
                      raise ValueError("party id should be positive integer")

          parties = self.__party_instance.setdefault(role, {"party": {}})["party"]
          if isinstance(party_id, list):
              party_key = "|".join(map(str, party_id))
          else:
              party_key = party_id
          if parties.get(party_key) is None:
              parties[party_key] = self._copy_without_party_cache()
          return parties[party_key]

      def _copy_without_party_cache(self):
          """Deep-copy this object, giving the copy an empty party-instance cache.

          Copying the cache as well would nest every previously created party
          instance inside each new one, so the number of copied objects would
          double with every new party instance.
          """
          party_cache = self.__party_instance
          self.__party_instance = {}
          try:
              return copy.deepcopy(self)
          finally:
              self.__party_instance = party_cache

      def job_param(self, **kwargs):
          """Set parameters on this instance and record them as job params.

          The values are deep-copied first, so later mutation of the caller's
          objects does not leak into the config.
          """
          new_kwargs = copy.deepcopy(kwargs)
          for attr, value in new_kwargs.items():
              setattr(self, attr, value)
              self._job_param[attr] = value

      def get_job_param(self):
          """Return the parameters recorded through job_param() (not a copy)."""
          return self._job_param

      def get_common_param_conf(self):
          """Return all public (non-underscore) instance attributes as a dict."""
          return {attr: value for attr, value in vars(self).items()
                  if not attr.startswith("_")}

      def get_role_param_conf(self, roles=None):
          """Build the per-role section of the job config.

          Args:
              roles: mapping role -> sequence of the job's party ids. A stored
                  party id is reported by its position in this sequence. Only
                  consulted when party-specific instances exist.

          Returns:
              {role: {"all" | "<idx>" | "<idx>|<idx>...": job-param dict}};
              empty when get_party_instance() was never called.

          Raises:
              ValueError: if a party-specific instance exists but ``roles`` does
                  not list that role, or does not contain one of its party ids.
          """
          role_param_conf = {}
          for role, role_entry in self.__party_instance.items():
              parties = role_entry["party"]
              role_conf = role_param_conf[role] = {}
              if None in parties:
                  role_conf["all"] = parties[None].get_job_param()
              for party_id, party_inst in parties.items():
                  # None is the role-wide entry handled above; "" comes from an
                  # empty party-id list and has no party to map.
                  if not party_id:
                      continue
                  party_key = self._party_index_key(role, party_id, roles)
                  role_conf[party_key] = party_inst.get_job_param()
          return role_param_conf

      @staticmethod
      def _party_index_key(role, party_id, roles):
          """Map a stored party key to "<idx>[|<idx>...]" positions in roles[role]."""
          valid_partyids = (roles or {}).get(role)
          if valid_partyids is None:
              raise ValueError(
                  "roles must list the party ids of role '{}' to build its "
                  "party-specific config".format(role))
          if isinstance(party_id, int):
              members = [party_id]
          else:
              members = [int(_id) for _id in party_id.split("|")]
          try:
              return "|".join(str(valid_partyids.index(_id)) for _id in members)
          except ValueError as err:
              raise ValueError(
                  "party id(s) {} of role '{}' not found in {}".format(
                      party_id, role, valid_partyids)) from err

      def get_config(self, *args, **kwargs):
          """Assemble the job-parameter section of a job configuration.

          Keyword Args:
              roles: mapping role -> sequence of party ids, used to index
                  party-specific settings (see get_role_param_conf). It may be
                  omitted when no party-specific instance was created.

          Returns:
              A dict with a "common" and/or a "role" key; each key is present
              only when its section is non-empty.
          """
          roles = kwargs.get("roles")
          common_param_conf = self.get_common_param_conf()
          role_param_conf = self.get_role_param_conf(roles)
          conf = {}
          if common_param_conf:
              conf["common"] = common_param_conf
          if role_param_conf:
              conf["role"] = role_param_conf
          return conf