_parties.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import typing
  17. from fate_arch.common import Party
  18. class Role:
  19. def __init__(self, parties) -> None:
  20. self._parties = parties
  21. self._size = len(self._parties)
  22. def __getitem__(self, key):
  23. return self._parties[key]
  24. class _PartiesMeta(type):
  25. @property
  26. def Guest(cls) -> Role:
  27. return cls._get_instance()._guest
  28. @property
  29. def Host(cls) -> Role:
  30. return cls._get_instance()._host
  31. @property
  32. def Arbiter(cls) -> Role:
  33. return cls._get_instance()._arbiter
  34. class PartiesInfo(metaclass=_PartiesMeta):
  35. _instance = None
  36. @classmethod
  37. def _set_instance(cls, inst):
  38. cls._instance = inst
  39. @classmethod
  40. def _get_instance(cls) -> "PartiesInfo":
  41. if cls._instance is None:
  42. raise RuntimeError(f"parties not initialized")
  43. return cls._instance
  44. @classmethod
  45. def get_parties(cls, parties) -> typing.List[Party]:
  46. if isinstance(parties, Party):
  47. return [parties]
  48. elif isinstance(parties, Role):
  49. return parties[:]
  50. elif isinstance(parties, list):
  51. plain_parties = []
  52. for p in parties:
  53. plain_parties.extend(cls.get_parties(p))
  54. if len(set(plain_parties)) != len(plain_parties):
  55. raise ValueError(f"duplicated parties exsits: {plain_parties}")
  56. return plain_parties
  57. raise ValueError(f"unsupported type: {type(parties)}")
  58. @staticmethod
  59. def from_conf(conf: typing.MutableMapping[str, dict]):
  60. try:
  61. local = Party(
  62. role=conf["local"]["role"], party_id=conf["local"]["party_id"]
  63. )
  64. role_to_parties = {}
  65. for role, party_id_list in conf.get("role", {}).items():
  66. role_to_parties[role] = [
  67. Party(role=role, party_id=party_id) for party_id in party_id_list
  68. ]
  69. except Exception as e:
  70. raise RuntimeError(
  71. "conf parse error, a correct configuration could be:\n"
  72. "{\n"
  73. " 'local': {'role': 'guest', 'party_id': 10000},\n"
  74. " 'role': {'guest': [10000], 'host': [9999, 9998]}, 'arbiter': [9997]}\n"
  75. "}"
  76. ) from e
  77. return PartiesInfo(local, role_to_parties)
  78. def __init__(
  79. self,
  80. local: Party,
  81. role_to_parties: typing.MutableMapping[str, typing.List[Party]],
  82. ):
  83. self._local = local
  84. self._role_to_parties = role_to_parties
  85. self._guest = Role(role_to_parties.get("guest", []))
  86. self._host = Role(role_to_parties.get("host", []))
  87. self._arbiter = Role(role_to_parties.get("arbiter", []))
  88. self._set_instance(self)
  89. @property
  90. def local_party(self) -> Party:
  91. return self._local
  92. @property
  93. def all_parties(self):
  94. return [
  95. party for parties in self._role_to_parties.values() for party in parties
  96. ]
  97. @property
  98. def role_set(self):
  99. return set(self._role_to_parties)
  100. def roles_to_parties(self, roles: typing.Iterable, strict=True) -> list:
  101. parties = []
  102. for role in roles:
  103. if role not in self._role_to_parties:
  104. if strict:
  105. raise RuntimeError(
  106. f"try to get role {role} "
  107. f"which is not configured in `role` in runtime conf({self._role_to_parties})"
  108. )
  109. else:
  110. continue
  111. parties.extend(self._role_to_parties[role])
  112. return parties
  113. def role_to_party(self, role, idx) -> Party:
  114. return self._role_to_parties[role][idx]
  115. __all__ = ["PartiesInfo", "Role"]