_federation.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. #
  2. # Copyright 2022 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import sys
  18. import typing
  19. from pickle import dumps as p_dumps, loads as p_loads
  20. from fate_arch.abc import CTableABC
  21. from fate_arch.abc import FederationABC, GarbageCollectionABC
  22. from fate_arch.common import Party
  23. from fate_arch.common.log import getLogger
  24. from fate_arch.federation import FederationDataType
  25. from fate_arch.federation._datastream import Datastream
  26. from fate_arch.session import computing_session
# Module-level logger used by the federation helpers in this file.
LOGGER = getLogger()
# Marker passed as the topic dtype and appended to the tag of the metadata
# message that announces a payload's dtype before the payload itself.
NAME_DTYPE_TAG = "<dtype>"
# Separator used when joining the components of topic, cache and message keys.
_SPLIT_ = "^"
  30. def _get_splits(obj, max_message_size):
  31. obj_bytes = p_dumps(obj, protocol=4)
  32. byte_size = len(obj_bytes)
  33. num_slice = (byte_size - 1) // max_message_size + 1
  34. if num_slice <= 1:
  35. return obj, num_slice
  36. else:
  37. _max_size = max_message_size
  38. kv = [(i, obj_bytes[slice(i * _max_size, (i + 1) * _max_size)]) for i in range(num_slice)]
  39. return kv, num_slice
  40. class FederationBase(FederationABC):
  41. @staticmethod
  42. def from_conf(
  43. federation_session_id: str,
  44. party: Party,
  45. runtime_conf: dict,
  46. **kwargs
  47. ):
  48. raise NotImplementedError()
  49. def __init__(
  50. self,
  51. session_id,
  52. party: Party,
  53. mq,
  54. max_message_size,
  55. conf=None
  56. ):
  57. self._session_id = session_id
  58. self._party = party
  59. self._mq = mq
  60. self._topic_map = {}
  61. self._channels_map = {}
  62. self._name_dtype_map = {}
  63. self._message_cache = {}
  64. self._max_message_size = max_message_size
  65. self._conf = conf
  66. def __getstate__(self):
  67. pass
  68. @property
  69. def session_id(self) -> str:
  70. return self._session_id
  71. def destroy(self, parties):
  72. raise NotImplementedError()
  73. def get(
  74. self, name: str, tag: str, parties: typing.List[Party], gc: GarbageCollectionABC
  75. ) -> typing.List:
  76. log_str = f"[federation.get](name={name}, tag={tag}, parties={parties})"
  77. LOGGER.debug(f"[{log_str}]start to get")
  78. _name_dtype_keys = [
  79. _SPLIT_.join([party.role, party.party_id, name, tag, "get"])
  80. for party in parties
  81. ]
  82. if _name_dtype_keys[0] not in self._name_dtype_map:
  83. party_topic_infos = self._get_party_topic_infos(parties, dtype=NAME_DTYPE_TAG)
  84. channel_infos = self._get_channels(party_topic_infos=party_topic_infos)
  85. rtn_dtype = []
  86. for i, info in enumerate(channel_infos):
  87. obj = self._receive_obj(
  88. info, name, tag=_SPLIT_.join([tag, NAME_DTYPE_TAG])
  89. )
  90. rtn_dtype.append(obj)
  91. LOGGER.debug(
  92. f"[federation.get] _name_dtype_keys: {_name_dtype_keys}, dtype: {obj}"
  93. )
  94. for k in _name_dtype_keys:
  95. if k not in self._name_dtype_map:
  96. self._name_dtype_map[k] = rtn_dtype[0]
  97. rtn_dtype = self._name_dtype_map[_name_dtype_keys[0]]
  98. rtn = []
  99. dtype = rtn_dtype.get("dtype", None)
  100. partitions = rtn_dtype.get("partitions", None)
  101. if dtype == FederationDataType.TABLE or dtype == FederationDataType.SPLIT_OBJECT:
  102. party_topic_infos = self._get_party_topic_infos(parties, name, partitions=partitions)
  103. for i in range(len(party_topic_infos)):
  104. party = parties[i]
  105. role = party.role
  106. party_id = party.party_id
  107. topic_infos = party_topic_infos[i]
  108. receive_func = self._get_partition_receive_func(
  109. name=name,
  110. tag=tag,
  111. src_party_id=self._party.party_id,
  112. src_role=self._party.role,
  113. dst_party_id=party_id,
  114. dst_role=role,
  115. topic_infos=topic_infos,
  116. mq=self._mq,
  117. conf=self._conf
  118. )
  119. table = computing_session.parallelize(range(partitions), partitions, include_key=False)
  120. table = table.mapPartitionsWithIndex(receive_func)
  121. # add gc
  122. gc.add_gc_action(tag, table, "__del__", {})
  123. LOGGER.debug(
  124. f"[{log_str}]received table({i + 1}/{len(parties)}), party: {parties[i]} "
  125. )
  126. if dtype == FederationDataType.TABLE:
  127. rtn.append(table)
  128. else:
  129. obj_bytes = b''.join(map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0])))
  130. obj = p_loads(obj_bytes)
  131. rtn.append(obj)
  132. else:
  133. party_topic_infos = self._get_party_topic_infos(parties, name)
  134. channel_infos = self._get_channels(party_topic_infos=party_topic_infos)
  135. for i, info in enumerate(channel_infos):
  136. obj = self._receive_obj(info, name, tag)
  137. LOGGER.debug(
  138. f"[{log_str}]received obj({i + 1}/{len(parties)}), party: {parties[i]} "
  139. )
  140. rtn.append(obj)
  141. LOGGER.debug(f"[{log_str}]finish to get")
  142. return rtn
  143. def remote(
  144. self,
  145. v,
  146. name: str,
  147. tag: str,
  148. parties: typing.List[Party],
  149. gc: GarbageCollectionABC,
  150. ) -> typing.NoReturn:
  151. log_str = f"[federation.remote](name={name}, tag={tag}, parties={parties})"
  152. _name_dtype_keys = [
  153. _SPLIT_.join([party.role, party.party_id, name, tag, "remote"])
  154. for party in parties
  155. ]
  156. if _name_dtype_keys[0] not in self._name_dtype_map:
  157. party_topic_infos = self._get_party_topic_infos(parties, dtype=NAME_DTYPE_TAG)
  158. channel_infos = self._get_channels(party_topic_infos=party_topic_infos)
  159. if not isinstance(v, CTableABC):
  160. v, num_slice = _get_splits(v, self._max_message_size)
  161. if num_slice > 1:
  162. v = computing_session.parallelize(data=v, partition=1, include_key=True)
  163. body = {"dtype": FederationDataType.SPLIT_OBJECT, "partitions": v.partitions}
  164. else:
  165. body = {"dtype": FederationDataType.OBJECT}
  166. else:
  167. body = {"dtype": FederationDataType.TABLE, "partitions": v.partitions}
  168. LOGGER.debug(
  169. f"[federation.remote] _name_dtype_keys: {_name_dtype_keys}, dtype: {body}"
  170. )
  171. self._send_obj(
  172. name=name,
  173. tag=_SPLIT_.join([tag, NAME_DTYPE_TAG]),
  174. data=p_dumps(body),
  175. channel_infos=channel_infos,
  176. )
  177. for k in _name_dtype_keys:
  178. if k not in self._name_dtype_map:
  179. self._name_dtype_map[k] = body
  180. if isinstance(v, CTableABC):
  181. total_size = v.count()
  182. partitions = v.partitions
  183. LOGGER.debug(
  184. f"[{log_str}]start to remote table, total_size={total_size}, partitions={partitions}"
  185. )
  186. party_topic_infos = self._get_party_topic_infos(parties, name, partitions=partitions)
  187. # add gc
  188. gc.add_gc_action(tag, v, "__del__", {})
  189. send_func = self._get_partition_send_func(
  190. name=name,
  191. tag=tag,
  192. partitions=partitions,
  193. party_topic_infos=party_topic_infos,
  194. src_party_id=self._party.party_id,
  195. src_role=self._party.role,
  196. mq=self._mq,
  197. max_message_size=self._max_message_size,
  198. conf=self._conf
  199. )
  200. # noinspection PyProtectedMember
  201. v.mapPartitionsWithIndex(send_func)
  202. else:
  203. LOGGER.debug(f"[{log_str}]start to remote obj")
  204. party_topic_infos = self._get_party_topic_infos(parties, name)
  205. channel_infos = self._get_channels(party_topic_infos=party_topic_infos)
  206. self._send_obj(
  207. name=name, tag=tag, data=p_dumps(v), channel_infos=channel_infos
  208. )
  209. LOGGER.debug(f"[{log_str}]finish to remote")
  210. def _get_party_topic_infos(
  211. self, parties: typing.List[Party], name=None, partitions=None, dtype=None
  212. ) -> typing.List:
  213. topic_infos = [
  214. self._get_or_create_topic(party, name, partitions, dtype)
  215. for party in parties
  216. ]
  217. return topic_infos
  218. def _maybe_create_topic_and_replication(self, party, topic_suffix):
  219. # gen names
  220. raise NotImplementedError()
  221. def _get_or_create_topic(
  222. self, party: Party, name=None, partitions=None, dtype=None
  223. ) -> typing.Tuple:
  224. topic_key_list = []
  225. topic_infos = []
  226. if dtype is not None:
  227. topic_key = _SPLIT_.join(
  228. [party.role, party.party_id, dtype, dtype])
  229. topic_key_list.append(topic_key)
  230. else:
  231. if partitions is not None:
  232. for i in range(partitions):
  233. topic_key = _SPLIT_.join(
  234. [party.role, party.party_id, name, str(i)])
  235. topic_key_list.append(topic_key)
  236. elif name is not None:
  237. topic_key = _SPLIT_.join([party.role, party.party_id, name])
  238. topic_key_list.append(topic_key)
  239. else:
  240. topic_key = _SPLIT_.join([party.role, party.party_id])
  241. topic_key_list.append(topic_key)
  242. for topic_key in topic_key_list:
  243. if topic_key not in self._topic_map:
  244. topic_key_splits = topic_key.split(_SPLIT_)
  245. topic_suffix = "-".join(topic_key_splits[2:])
  246. topic_pair = self._maybe_create_topic_and_replication(party, topic_suffix)
  247. self._topic_map[topic_key] = topic_pair
  248. topic_pair = self._topic_map[topic_key]
  249. topic_infos.append((topic_key, topic_pair))
  250. return topic_infos
  251. def _get_channel(
  252. self, topic_pair, src_party_id, src_role, dst_party_id, dst_role, mq=None, conf: dict = None):
  253. raise NotImplementedError()
  254. def _get_channels(self, party_topic_infos):
  255. channel_infos = []
  256. for e in party_topic_infos:
  257. for topic_key, topic_pair in e:
  258. topic_key_splits = topic_key.split(_SPLIT_)
  259. role = topic_key_splits[0]
  260. party_id = topic_key_splits[1]
  261. info = self._channels_map.get(topic_key)
  262. if info is None:
  263. info = self._get_channel(
  264. topic_pair=topic_pair,
  265. src_party_id=self._party.party_id,
  266. src_role=self._party.role,
  267. dst_party_id=party_id,
  268. dst_role=role,
  269. mq=self._mq,
  270. conf=self._conf
  271. )
  272. self._channels_map[topic_key] = info
  273. channel_infos.append(info)
  274. return channel_infos
  275. def _get_channels_index(self, index, party_topic_infos, src_party_id, src_role, mq=None, conf: dict = None):
  276. channel_infos = []
  277. for e in party_topic_infos:
  278. # select specified topic_info for a party
  279. topic_key, topic_pair = e[index]
  280. topic_key_splits = topic_key.split(_SPLIT_)
  281. role = topic_key_splits[0]
  282. party_id = topic_key_splits[1]
  283. info = self._get_channel(
  284. topic_pair=topic_pair,
  285. src_party_id=src_party_id,
  286. src_role=src_role,
  287. dst_party_id=party_id,
  288. dst_role=role,
  289. mq=mq,
  290. conf=conf
  291. )
  292. channel_infos.append(info)
  293. return channel_infos
  294. def _send_obj(self, name, tag, data, channel_infos):
  295. for info in channel_infos:
  296. properties = {
  297. "content_type": "text/plain",
  298. "app_id": info._dst_party_id,
  299. "message_id": name,
  300. "correlation_id": tag
  301. }
  302. LOGGER.debug(f"[federation._send_obj]properties:{properties}.")
  303. info.produce(body=data, properties=properties)
  304. def _send_kv(
  305. self, name, tag, data, channel_infos, partition_size, partitions, message_key
  306. ):
  307. headers = json.dumps(
  308. {
  309. "partition_size": partition_size,
  310. "partitions": partitions,
  311. "message_key": message_key
  312. }
  313. )
  314. for info in channel_infos:
  315. properties = {
  316. "content_type": "application/json",
  317. "app_id": info._dst_party_id,
  318. "message_id": name,
  319. "correlation_id": tag,
  320. "headers": headers
  321. }
  322. print(f"[federation._send_kv]info: {info}, properties: {properties}.")
  323. info.produce(body=data, properties=properties)
  324. def _get_partition_send_func(
  325. self,
  326. name,
  327. tag,
  328. partitions,
  329. party_topic_infos,
  330. src_party_id,
  331. src_role,
  332. mq,
  333. max_message_size,
  334. conf: dict,
  335. ):
  336. def _fn(index, kvs):
  337. return self._partition_send(
  338. index=index,
  339. kvs=kvs,
  340. name=name,
  341. tag=tag,
  342. partitions=partitions,
  343. party_topic_infos=party_topic_infos,
  344. src_party_id=src_party_id,
  345. src_role=src_role,
  346. mq=mq,
  347. max_message_size=max_message_size,
  348. conf=conf,
  349. )
  350. return _fn
  351. def _partition_send(
  352. self,
  353. index,
  354. kvs,
  355. name,
  356. tag,
  357. partitions,
  358. party_topic_infos,
  359. src_party_id,
  360. src_role,
  361. mq,
  362. max_message_size,
  363. conf: dict,
  364. ):
  365. channel_infos = self._get_channels_index(
  366. index=index, party_topic_infos=party_topic_infos, src_party_id=src_party_id, src_role=src_role, mq=mq,
  367. conf=conf
  368. )
  369. datastream = Datastream()
  370. base_message_key = str(index)
  371. message_key_idx = 0
  372. count = 0
  373. for k, v in kvs:
  374. count += 1
  375. el = {"k": p_dumps(k).hex(), "v": p_dumps(v).hex()}
  376. # roughly caculate the size of package to avoid serialization ;)
  377. if (
  378. datastream.get_size() + sys.getsizeof(el["k"]) + sys.getsizeof(el["v"])
  379. >= max_message_size
  380. ):
  381. print(
  382. f"[federation._partition_send]The size of message is: {datastream.get_size()}"
  383. )
  384. message_key_idx += 1
  385. message_key = base_message_key + "_" + str(message_key_idx)
  386. self._send_kv(
  387. name=name,
  388. tag=tag,
  389. data=datastream.get_data().encode(),
  390. channel_infos=channel_infos,
  391. partition_size=-1,
  392. partitions=partitions,
  393. message_key=message_key,
  394. )
  395. datastream.clear()
  396. datastream.append(el)
  397. message_key_idx += 1
  398. message_key = _SPLIT_.join([base_message_key, str(message_key_idx)])
  399. self._send_kv(
  400. name=name,
  401. tag=tag,
  402. data=datastream.get_data().encode(),
  403. channel_infos=channel_infos,
  404. partition_size=count,
  405. partitions=partitions,
  406. message_key=message_key,
  407. )
  408. return [(index, 1)]
  409. def _get_message_cache_key(self, name, tag, party_id, role):
  410. cache_key = _SPLIT_.join([name, tag, str(party_id), role])
  411. return cache_key
  412. def _get_consume_message(self, channel_info):
  413. raise NotImplementedError()
  414. def _consume_ack(self, channel_info, id):
  415. raise NotImplementedError()
  416. def _query_receive_topic(self, channel_info):
  417. return channel_info
  418. def _receive_obj(self, channel_info, name, tag):
  419. party_id = channel_info._dst_party_id
  420. role = channel_info._dst_role
  421. wish_cache_key = self._get_message_cache_key(name, tag, party_id, role)
  422. if wish_cache_key in self._message_cache:
  423. recv_obj = self._message_cache[wish_cache_key]
  424. del self._message_cache[wish_cache_key]
  425. return recv_obj
  426. channel_info = self._query_receive_topic(channel_info)
  427. for id, properties, body in self._get_consume_message(channel_info):
  428. LOGGER.debug(
  429. f"[federation._receive_obj] properties: {properties}"
  430. )
  431. if properties["message_id"] != name or properties["correlation_id"] != tag:
  432. # todo: fix this
  433. LOGGER.warning(
  434. f"[federation._receive_obj] require {name}.{tag}, got {properties['message_id']}.{properties['correlation_id']}"
  435. )
  436. cache_key = self._get_message_cache_key(
  437. properties["message_id"], properties["correlation_id"], party_id, role
  438. )
  439. # object
  440. if properties["content_type"] == "text/plain":
  441. recv_obj = p_loads(body)
  442. self._consume_ack(channel_info, id)
  443. LOGGER.debug(
  444. f"[federation._receive_obj] cache_key: {cache_key}, wish_cache_key: {wish_cache_key}"
  445. )
  446. if cache_key == wish_cache_key:
  447. channel_info.cancel()
  448. return recv_obj
  449. else:
  450. self._message_cache[cache_key] = recv_obj
  451. else:
  452. raise ValueError(
  453. f"[federation._receive_obj] properties.content_type is {properties['content_type']}, but must be text/plain"
  454. )
  455. def _get_partition_receive_func(
  456. self, name, tag, src_party_id, src_role, dst_party_id, dst_role, topic_infos, mq, conf: dict
  457. ):
  458. def _fn(index, kvs):
  459. return self._partition_receive(
  460. index=index,
  461. kvs=kvs,
  462. name=name,
  463. tag=tag,
  464. src_party_id=src_party_id,
  465. src_role=src_role,
  466. dst_party_id=dst_party_id,
  467. dst_role=dst_role,
  468. topic_infos=topic_infos,
  469. mq=mq,
  470. conf=conf,
  471. )
  472. return _fn
  473. def _partition_receive(
  474. self,
  475. index,
  476. kvs,
  477. name,
  478. tag,
  479. src_party_id,
  480. src_role,
  481. dst_party_id,
  482. dst_role,
  483. topic_infos,
  484. mq,
  485. conf: dict,
  486. ):
  487. topic_pair = topic_infos[index][1]
  488. channel_info = self._get_channel(topic_pair=topic_pair,
  489. src_party_id=src_party_id,
  490. src_role=src_role,
  491. dst_party_id=dst_party_id,
  492. dst_role=dst_role,
  493. mq=mq,
  494. conf=conf)
  495. message_key_cache = set()
  496. count = 0
  497. partition_size = -1
  498. all_data = []
  499. channel_info = self._query_receive_topic(channel_info)
  500. while True:
  501. try:
  502. for id, properties, body in self._get_consume_message(channel_info):
  503. print(
  504. f"[federation._partition_receive] properties: {properties}."
  505. )
  506. if properties["message_id"] != name or properties["correlation_id"] != tag:
  507. # todo: fix this
  508. self._consume_ack(channel_info, id)
  509. print(
  510. f"[federation._partition_receive]: require {name}.{tag}, got {properties['message_id']}.{properties['correlation_id']}"
  511. )
  512. continue
  513. if properties["content_type"] == "application/json":
  514. header = json.loads(properties["headers"])
  515. message_key = header["message_key"]
  516. if message_key in message_key_cache:
  517. print(
  518. f"[federation._partition_receive] message_key : {message_key} is duplicated"
  519. )
  520. self._consume_ack(channel_info, id)
  521. continue
  522. message_key_cache.add(message_key)
  523. if header["partition_size"] >= 0:
  524. partition_size = header["partition_size"]
  525. data = json.loads(body.decode())
  526. data_iter = (
  527. (p_loads(bytes.fromhex(el["k"])), p_loads(bytes.fromhex(el["v"])))
  528. for el in data
  529. )
  530. count += len(data)
  531. print(f"[federation._partition_receive] count: {count}")
  532. all_data.extend(data_iter)
  533. self._consume_ack(channel_info, id)
  534. if count == partition_size:
  535. channel_info.cancel()
  536. return all_data
  537. else:
  538. ValueError(
  539. f"[federation._partition_receive]properties.content_type is {properties['content_type']}, but must be application/json"
  540. )
  541. except Exception as e:
  542. LOGGER.error(
  543. f"[federation._partition_receive]catch exception {e}, while receiving {name}.{tag}"
  544. )
  545. # avoid hang on consume()
  546. if count == partition_size:
  547. channel_info.cancel()
  548. return all_data
  549. else:
  550. raise e