transfer_variable.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import hashlib
  17. import typing
  18. from typing import Union
  19. from fate_arch.common import Party, profile
  20. from fate_arch.common.log import getLogger
  21. from fate_arch.federation._gc import IterationGC
  22. __all__ = ["Variable", "BaseTransferVariables"]
  23. LOGGER = getLogger()
  24. class FederationTagNamespace(object):
  25. __namespace = "default"
  26. @classmethod
  27. def set_namespace(cls, namespace):
  28. cls.__namespace = namespace
  29. @classmethod
  30. def generate_tag(cls, *suffix):
  31. tags = (cls.__namespace, *map(str, suffix))
  32. return ".".join(tags)
  33. class Variable(object):
  34. """
  35. variable to distinguish federation by name
  36. """
  37. __instances: typing.MutableMapping[str, "Variable"] = {}
  38. @classmethod
  39. def get_or_create(
  40. cls, name, create_func: typing.Callable[[], "Variable"]
  41. ) -> "Variable":
  42. if name not in cls.__instances:
  43. value = create_func()
  44. cls.__instances[name] = value
  45. return cls.__instances[name]
  46. def __init__(
  47. self, name: str, src: typing.Tuple[str, ...], dst: typing.Tuple[str, ...]
  48. ):
  49. if name in self.__instances:
  50. raise RuntimeError(
  51. f"{self.__instances[name]} with {name} already initialized, which expected to be an singleton object."
  52. )
  53. assert (
  54. len(name.split(".")) >= 3
  55. ), "incorrect name format, should be `module_name.class_name.variable_name`"
  56. self._name = name
  57. self._src = src
  58. self._dst = dst
  59. self._get_gc = IterationGC()
  60. self._remote_gc = IterationGC()
  61. self._use_short_name = True
  62. self._short_name = self._get_short_name(self._name)
  63. @staticmethod
  64. def _get_short_name(name):
  65. fix_sized = hashlib.blake2b(name.encode("utf-8"), digest_size=10).hexdigest()
  66. _, right = name.rsplit(".", 1)
  67. return f"hash.{fix_sized}.{right}"
  68. # copy never create a new instance
  69. def __copy__(self):
  70. return self
  71. # deepcopy never create a new instance
  72. def __deepcopy__(self, memo):
  73. return self
  74. def set_preserve_num(self, n):
  75. self._get_gc.set_capacity(n)
  76. self._remote_gc.set_capacity(n)
  77. return self
  78. def disable_auto_clean(self):
  79. self._get_gc.disable()
  80. self._remote_gc.disable()
  81. return self
  82. def clean(self):
  83. self._get_gc.clean()
  84. self._remote_gc.clean()
  85. def remote_parties(
  86. self,
  87. obj,
  88. parties: Union[typing.List[Party], Party],
  89. suffix: Union[typing.Any, typing.Tuple] = tuple(),
  90. ):
  91. """
  92. remote object to specified parties
  93. Parameters
  94. ----------
  95. obj: object or table
  96. object or table to remote
  97. parties: typing.List[Party]
  98. parties to remote object/table to
  99. suffix: str or tuple of str
  100. suffix used to distinguish federation with in variable
  101. Returns
  102. -------
  103. None
  104. """
  105. from fate_arch.session import get_session
  106. session = get_session()
  107. if isinstance(parties, Party):
  108. parties = [parties]
  109. if not isinstance(suffix, tuple):
  110. suffix = (suffix,)
  111. tag = FederationTagNamespace.generate_tag(*suffix)
  112. for party in parties:
  113. if party.role not in self._dst:
  114. raise RuntimeError(
  115. f"not allowed to remote object to {party} using {self._name}"
  116. )
  117. local = session.parties.local_party.role
  118. if local not in self._src:
  119. raise RuntimeError(
  120. f"not allowed to remote object from {local} using {self._name}"
  121. )
  122. name = self._short_name if self._use_short_name else self._name
  123. timer = profile.federation_remote_timer(name, self._name, tag, local, parties)
  124. session.federation.remote(
  125. v=obj, name=name, tag=tag, parties=parties, gc=self._remote_gc
  126. )
  127. timer.done(session.federation)
  128. self._remote_gc.gc()
  129. def get_parties(
  130. self,
  131. parties: Union[typing.List[Party], Party],
  132. suffix: Union[typing.Any, typing.Tuple] = tuple(),
  133. ):
  134. """
  135. get objects/tables from specified parties
  136. Parameters
  137. ----------
  138. parties: typing.List[Party]
  139. parties to remote object/table to
  140. suffix: str or tuple of str
  141. suffix used to distinguish federation with in variable
  142. Returns
  143. -------
  144. list
  145. a list of objects/tables get from parties with same order of ``parties``
  146. """
  147. from fate_arch.session import get_session
  148. session = get_session()
  149. if not isinstance(parties, list):
  150. parties = [parties]
  151. if not isinstance(suffix, tuple):
  152. suffix = (suffix,)
  153. tag = FederationTagNamespace.generate_tag(*suffix)
  154. for party in parties:
  155. if party.role not in self._src:
  156. raise RuntimeError(
  157. f"not allowed to get object from {party} using {self._name}"
  158. )
  159. local = session.parties.local_party.role
  160. if local not in self._dst:
  161. raise RuntimeError(
  162. f"not allowed to get object to {local} using {self._name}"
  163. )
  164. name = self._short_name if self._use_short_name else self._name
  165. timer = profile.federation_get_timer(name, self._name, tag, local, parties)
  166. rtn = session.federation.get(
  167. name=name, tag=tag, parties=parties, gc=self._get_gc
  168. )
  169. timer.done(session.federation)
  170. self._get_gc.gc()
  171. return rtn
  172. def remote(self, obj, role=None, idx=-1, suffix=tuple()):
  173. """
  174. send obj to other parties.
  175. Args:
  176. obj: object to be sent
  177. role: role of parties to sent to, use one of ['Host', 'Guest', 'Arbiter', None].
  178. The default is None, means sent values to parties regardless their party role
  179. idx: id of party to sent to.
  180. The default is -1, which means sent values to parties regardless their party id
  181. suffix: additional tag suffix, the default is tuple()
  182. """
  183. from fate_arch.session import get_parties
  184. party_info = get_parties()
  185. if idx >= 0 and role is None:
  186. raise ValueError("role cannot be None if idx specified")
  187. # get subset of dst roles in runtime conf
  188. if role is None:
  189. parties = party_info.roles_to_parties(self._dst, strict=False)
  190. else:
  191. if isinstance(role, str):
  192. role = [role]
  193. parties = party_info.roles_to_parties(role)
  194. if idx >= 0:
  195. if idx >= len(parties):
  196. raise RuntimeError(
  197. f"try to remote to {idx}th party while only {len(parties)} configurated: {parties}, check {self._name}"
  198. )
  199. parties = parties[idx]
  200. return self.remote_parties(obj=obj, parties=parties, suffix=suffix)
  201. def get(self, idx=-1, role=None, suffix=tuple()):
  202. """
  203. get obj from other parties.
  204. Args:
  205. idx: id of party to get from.
  206. The default is -1, which means get values from parties regardless their party id
  207. suffix: additional tag suffix, the default is tuple()
  208. Returns:
  209. object or list of object
  210. """
  211. from fate_arch.session import get_parties
  212. if role is None:
  213. src_parties = get_parties().roles_to_parties(roles=self._src, strict=False)
  214. else:
  215. if isinstance(role, str):
  216. role = [role]
  217. src_parties = get_parties().roles_to_parties(roles=role, strict=False)
  218. if isinstance(idx, list):
  219. rtn = self.get_parties(parties=[src_parties[i] for i in idx], suffix=suffix)
  220. elif isinstance(idx, int):
  221. if idx < 0:
  222. rtn = self.get_parties(parties=src_parties, suffix=suffix)
  223. else:
  224. if idx >= len(src_parties):
  225. raise RuntimeError(
  226. f"try to get from {idx}th party while only {len(src_parties)} configurated: {src_parties}, check {self._name}"
  227. )
  228. rtn = self.get_parties(parties=src_parties[idx], suffix=suffix)[0]
  229. else:
  230. raise ValueError(
  231. f"illegal idx type: {type(idx)}, supported types: int or list of int"
  232. )
  233. return rtn
  234. class BaseTransferVariables(object):
  235. def __init__(self, *args):
  236. pass
  237. def __copy__(self):
  238. return self
  239. def __deepcopy__(self, memo):
  240. return self
  241. @staticmethod
  242. def set_flowid(flowid):
  243. """
  244. set global namespace for federations.
  245. Parameters
  246. ----------
  247. flowid: str
  248. namespace
  249. Returns
  250. -------
  251. None
  252. """
  253. FederationTagNamespace.set_namespace(str(flowid))
  254. def _create_variable(
  255. self, name: str, src: typing.Iterable[str], dst: typing.Iterable[str]
  256. ) -> Variable:
  257. full_name = f"{self.__module__}.{self.__class__.__name__}.{name}"
  258. return Variable.get_or_create(
  259. full_name, lambda: Variable(name=full_name, src=tuple(src), dst=tuple(dst))
  260. )
  261. @staticmethod
  262. def all_parties():
  263. """
  264. get all parties
  265. Returns
  266. -------
  267. list
  268. list of parties
  269. """
  270. from fate_arch.session import get_parties
  271. return get_parties().all_parties
  272. @staticmethod
  273. def local_party():
  274. """
  275. indicate local party
  276. Returns
  277. -------
  278. Party
  279. party this program running on
  280. """
  281. from fate_arch.session import get_parties
  282. return get_parties().local_party