# _federation.py (2.4 KB)
  1. import typing
  2. from fate_arch._standalone import Federation as RawFederation, Table as RawTable
  3. from fate_arch.abc import FederationABC
  4. from fate_arch.abc import GarbageCollectionABC
  5. from fate_arch.common import Party, log
  6. from fate_arch.computing.standalone import Table
  7. LOGGER = log.getLogger()
  8. class Federation(FederationABC):
  9. def __init__(self, standalone_session, federation_session_id, party):
  10. LOGGER.debug(
  11. f"[federation.standalone]init federation: "
  12. f"standalone_session={standalone_session}, "
  13. f"federation_session_id={federation_session_id}, "
  14. f"party={party}"
  15. )
  16. self._session_id = federation_session_id
  17. self._federation = RawFederation(
  18. standalone_session, federation_session_id, party
  19. )
  20. LOGGER.debug("[federation.standalone]init federation context done")
  21. @property
  22. def session_id(self) -> str:
  23. return self._session_id
  24. def remote(
  25. self,
  26. v,
  27. name: str,
  28. tag: str,
  29. parties: typing.List[Party],
  30. gc: GarbageCollectionABC,
  31. ):
  32. if not _remote_tag_not_duplicate(name, tag, parties):
  33. raise ValueError(f"remote to {parties} with duplicate tag: {name}.{tag}")
  34. if isinstance(v, Table):
  35. # noinspection PyProtectedMember
  36. v = v._table
  37. return self._federation.remote(v=v, name=name, tag=tag, parties=parties)
  38. # noinspection PyProtectedMember
  39. def get(
  40. self, name: str, tag: str, parties: typing.List[Party], gc: GarbageCollectionABC
  41. ) -> typing.List:
  42. for party in parties:
  43. if not _get_tag_not_duplicate(name, tag, party):
  44. raise ValueError(f"get from {party} with duplicate tag: {name}.{tag}")
  45. rtn = self._federation.get(name=name, tag=tag, parties=parties)
  46. return [Table(r) if isinstance(r, RawTable) else r for r in rtn]
  47. def destroy(self, parties):
  48. self._federation.destroy()
  49. _remote_history = set()
  50. def _remote_tag_not_duplicate(name, tag, parties):
  51. for party in parties:
  52. if (name, tag, party) in _remote_history:
  53. return False
  54. _remote_history.add((name, tag, party))
  55. return True
  56. _get_history = set()
  57. def _get_tag_not_duplicate(name, tag, party):
  58. if (name, tag, party) in _get_history:
  59. return False
  60. _get_history.add((name, tag, party))
  61. return True