_federation.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import concurrent.futures
  17. import os
  18. import signal
  19. from enum import Enum
  20. from eggroll.roll_pair.roll_pair import RollPair
  21. from eggroll.roll_site.roll_site import RollSiteContext
  22. from fate_arch.abc import FederationABC
  23. from fate_arch.common.log import getLogger
  24. from fate_arch.computing.eggroll import Table
  25. from fate_arch.common import remote_status
  LOGGER = getLogger()


  class Federation(FederationABC):
      """Eggroll federation: exchanges objects and RollPairs with other parties through roll site."""

      def __init__(self, rp_ctx, rs_session_id, party, proxy_endpoint):
          """Build the RollSiteContext used by every get/remote call.

          :param rp_ctx: roll pair context; its ``session_id`` is logged and it is
              handed to ``RollSiteContext`` and later used by ``destroy``.
          :param rs_session_id: roll-site session id; also the namespace cleaned by ``destroy``.
          :param party: local party; its ``role`` and ``party_id`` go into the roll-site options.
          :param proxy_endpoint: proxy endpoint forwarded in the roll-site options.
          """
          LOGGER.debug(
              f"[federation.eggroll]init federation: "
              f"rp_session_id={rp_ctx.session_id}, rs_session_id={rs_session_id}, "
              f"party={party}, proxy_endpoint={proxy_endpoint}"
          )
          rs_options = dict(
              self_role=party.role,
              self_party_id=party.party_id,
              proxy_endpoint=proxy_endpoint,
          )
          self._session_id = rs_session_id
          self._rp_ctx = rp_ctx
          self._rsc = RollSiteContext(rs_session_id, rp_ctx=rp_ctx, options=rs_options)
          LOGGER.debug(f"[federation.eggroll]init federation context done")

      @property
      def session_id(self) -> str:
          """Roll-site session id this federation was created with."""
          return self._session_id

      @staticmethod
      def _party_keys(parties):
          # _get/_remote take parties as plain (role, party_id) tuples.
          return [(p.role, p.party_id) for p in parties]

      def get(self, name, tag, parties, gc):
          """Pull ``name``/``tag`` from each party, in ``parties`` order.

          RollPair results are wrapped in ``Table``; anything else is returned as-is.
          """
          received = _get(name, tag, self._party_keys(parties), self._rsc, gc)
          return [Table(item) if isinstance(item, RollPair) else item for item in received]

      def remote(self, v, name, tag, parties, gc):
          """Push ``v`` to every party; a ``Table`` is unwrapped to its RollPair first."""
          # noinspection PyProtectedMember
          payload = v._rp if isinstance(v, Table) else v
          _remote(payload, name, tag, self._party_keys(parties), self._rsc, gc)

      def destroy(self, parties):
          """Clean up every name (``"*"``) in this session's namespace; ``parties`` is unused."""
          self._rp_ctx.cleanup(name="*", namespace=self._session_id)
  def _remote(v, name, tag, parties, rsc, gc):
      """Validate ``v`` and push it to ``parties`` under ``name``/``tag``.

      :param v: payload; a RollPair is registered with ``gc`` for destruction
          before being pushed, any other object is pushed as-is.
      :param name: variable name the transfer is loaded under.
      :param tag: transfer tag; each (name, tag, party) may only be remoted once.
      :param parties: list of ``(role, party_id)`` tuples.
      :param rsc: roll-site context passed on to the push helper.
      :param gc: collector that receives a ``destroy`` action for RollPair payloads.
      :raises ValueError: if ``v`` is None or the tag was already used for a party.
      :raises NotImplementedError: if ``v`` is classified as an unknown value type.
      """
      # Same prefix shape as the get side; the parties are spelled out in each
      # message instead of being glued onto the tag.
      log_str = f"federation.eggroll.remote.{name}.{tag}"
      if v is None:
          raise ValueError(f"[{log_str}]remote `None` to {parties}")
      if not _remote_tag_not_duplicate(name, tag, parties):
          raise ValueError(f"[{log_str}]remote to {parties} with duplicate tag")

      t = _get_type(v)
      if t == _FederationValueType.ROLL_PAIR:
          LOGGER.debug(
              f"[{log_str}]remote to {parties}: "
              f"RollPair(namespace={v.get_namespace()}, name={v.get_name()}, partitions={v.get_partitions()})"
          )
          gc.add_gc_action(tag, v, "destroy", {})
      elif t == _FederationValueType.OBJECT:
          LOGGER.debug(f"[{log_str}]remote to {parties}: object with type: {type(v)}")
      else:
          raise NotImplementedError(f"t={t}")
      _push_with_exception_handle(rsc, v, name, tag, parties)
  def _get(name, tag, parties, rsc, gc):
      """Pull ``name``/``tag`` from each party and return the values in ``parties`` order.

      Futures are consumed as they complete. Each value goes through
      ``_get_value_post_process`` before being stored. An exception raised by
      any future propagates out of ``result()``.
      """
      roll_site = rsc.load(name=name, tag=tag)
      owner_of = {fut: who for fut, who in zip(roll_site.pull(parties=parties), parties)}
      by_party = {}
      for done in concurrent.futures.as_completed(owner_of):
          who = owner_of[done]
          by_party[who] = _get_value_post_process(done.result(), name, tag, who, gc)
      return [by_party[who] for who in parties]
  # Kind of payload handled by _remote: a RollPair (ROLL_PAIR) or anything else (OBJECT).
  _FederationValueType = Enum("_FederationValueType", [("OBJECT", 1), ("ROLL_PAIR", 2)])
  # (name, tag, party) triples already remoted by this process.
  _remote_history = set()


  def _remote_tag_not_duplicate(name, tag, parties):
      """Record ``(name, tag, party)`` for every party unless any of them is a repeat.

      The check is all-or-nothing. If any triple was already recorded, or a
      party appears twice in ``parties``, nothing is recorded and False is
      returned. A rejected call therefore does not mark the other parties as
      used. Otherwise all triples are recorded and True is returned.

      :param name: variable name being remoted.
      :param tag: transfer tag.
      :param parties: iterable of ``(role, party_id)`` tuples.
      :return: True when none of the triples had been seen before.
      """
      keys = [(name, tag, party) for party in parties]
      if len(set(keys)) != len(keys) or not _remote_history.isdisjoint(keys):
          return False
      _remote_history.update(keys)
      return True
  def _get_type(v):
      """Classify ``v`` as ROLL_PAIR when it is a RollPair, otherwise as OBJECT."""
      return (
          _FederationValueType.ROLL_PAIR
          if isinstance(v, RollPair)
          else _FederationValueType.OBJECT
      )
  def _push_with_exception_handle(rsc, v, name, tag, parties):
      """Push ``v`` to ``parties`` and terminate this process if any push fails.

      One done-callback is attached to each per-party future. When a future
      fails, its callback:
      - logs the error,
      - prints the traceback to stdout,
      - sends SIGTERM to the current process,
      - re-raises.
      All futures are also handed to ``remote_status.add_remote_futures``.

      :param rsc: roll-site context; ``rsc.load(name=..., tag=...)`` provides the pusher.
      :param v: payload to push.
      :param name: variable name of the transfer.
      :param tag: transfer tag.
      :param parties: list of ``(role, party_id)`` tuples, zipped with the returned futures.
      :return: the object returned by ``rsc.load``.
      """

      def _remote_exception_re_raise(f, p):
          try:
              f.result()
          except Exception:
              pid = os.getpid()
              LOGGER.exception(
                  f"[federation.eggroll.remote.{name}.{tag}]future to remote to party: {p} fail,"
                  f" terminating process(pid={pid})"
              )
              import traceback

              # SIGTERM's default action ends the process without flushing
              # Python's stdio buffers, so flush explicitly or this message
              # can be lost when stdout is a pipe or a file.
              print(
                  f"federation.eggroll.remote.{name}.{tag} future to remote to party: {p} fail,"
                  f" terminating process {pid}, traceback: {traceback.format_exc()}",
                  flush=True,
              )
              os.kill(pid, signal.SIGTERM)
              raise
          else:
              # Kept outside the try so a logging error cannot kill the process.
              LOGGER.debug(
                  f"[federation.eggroll.remote.{name}.{tag}]future to remote to party: {p} done"
              )

      def _get_call_back_func(p):
          # Factory binds ``p`` per party, avoiding late binding of the loop variable.
          def _callback(f):
              return _remote_exception_re_raise(f, p)

          return _callback

      rs = rsc.load(name=name, tag=tag)
      futures = rs.push(obj=v, parties=parties)
      for party, future in zip(parties, futures):
          future.add_done_callback(_get_call_back_func(party))
      remote_status.add_remote_futures(futures)
      return rs
  # (name, tag, party) triples already received by this process.
  _get_history = set()


  def _get_tag_not_duplicate(name, tag, party):
      """Record ``(name, tag, party)``; return False if it had already been recorded."""
      key = (name, tag, party)
      is_new = key not in _get_history
      if is_new:
          _get_history.add(key)
      return is_new
  def _get_value_post_process(v, name, tag, party, gc):
      """Validate one pulled value and register RollPair values with ``gc``.

      :param v: value pulled from ``party``.
      :param name: variable name of the transfer.
      :param tag: transfer tag.
      :param party: ``(role, party_id)`` the value came from.
      :param gc: collector that receives a ``destroy`` action for RollPair values.
      :raises ValueError: if ``v`` is None or ``(name, tag, party)`` was already received.
      :return: ``v`` unchanged.
      """
      where = f"federation.eggroll.get.{name}.{tag}"
      if v is None:
          raise ValueError(f"[{where}]get `None` from {party}")
      if not _get_tag_not_duplicate(name, tag, party):
          raise ValueError(f"[{where}]get from {party} with duplicate tag")

      if not isinstance(v, RollPair):
          LOGGER.debug(f"[{where}] got object with type: {type(v)}")
          return v

      LOGGER.debug(
          f"[{where}] got "
          f"RollPair(namespace={v.get_namespace()}, name={v.get_name()}, partitions={v.get_partitions()})"
      )
      gc.add_gc_action(tag, v, "destroy", {})
      return v