123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629 |
- #
- # Copyright 2022 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import json
- import sys
- import typing
- from pickle import dumps as p_dumps, loads as p_loads
- from fate_arch.abc import CTableABC
- from fate_arch.abc import FederationABC, GarbageCollectionABC
- from fate_arch.common import Party
- from fate_arch.common.log import getLogger
- from fate_arch.federation import FederationDataType
- from fate_arch.federation._datastream import Datastream
- from fate_arch.session import computing_session
# Module-level logger used by every federation helper in this file.
LOGGER = getLogger()
# Marker for the metadata ("dtype") message that precedes each payload: it is
# used as the topic dtype and appended to the tag of that metadata message.
NAME_DTYPE_TAG = "<dtype>"
# Separator used when joining role / party id / name / tag into composite
# keys (topic keys, dtype-cache keys, message-cache keys).
_SPLIT_ = "^"
def _get_splits(obj, max_message_size):
    """Split *obj* into pickled chunks when it exceeds one message.

    The object is pickled with protocol 4 and the number of slices needed to
    keep every slice at or below ``max_message_size`` bytes is computed.

    Returns:
        ``(obj, num_slice)`` with the original, un-pickled object when a
        single slice is enough; otherwise ``(chunks, num_slice)`` where
        ``chunks`` is a list of ``(slice_index, bytes)`` pairs that, joined
        in index order, reproduce the pickled payload.
    """
    payload = p_dumps(obj, protocol=4)
    num_slice = (len(payload) - 1) // max_message_size + 1

    # Small enough for one message: hand the object back untouched.
    if num_slice <= 1:
        return obj, num_slice

    chunks = []
    for idx in range(num_slice):
        start = idx * max_message_size
        chunks.append((idx, payload[start:start + max_message_size]))
    return chunks, num_slice
class FederationBase(FederationABC):
    """Message-queue based skeleton shared by concrete federation backends.

    The class implements ``get`` / ``remote`` on top of a handful of hooks
    that raise ``NotImplementedError`` here and must be supplied by a
    subclass: ``from_conf``, ``destroy``,
    ``_maybe_create_topic_and_replication``, ``_get_channel``,
    ``_get_consume_message`` and ``_consume_ack``.

    Every transfer is preceded by a small metadata message (sent on the
    ``NAME_DTYPE_TAG`` topic) that tells the receiver whether the payload is
    a plain object, a table, or an object split into a one-partition table.
    """

    @staticmethod
    def from_conf(
        federation_session_id: str,
        party: Party,
        runtime_conf: dict,
        **kwargs
    ):
        """Build a federation from runtime configuration (subclass hook)."""
        raise NotImplementedError()

    def __init__(
        self,
        session_id,
        party: Party,
        mq,
        max_message_size,
        conf=None
    ):
        """Store session parameters and initialise the local caches.

        Args:
            session_id: identifier exposed through ``session_id``.
            party: the local party; its ``role`` / ``party_id`` are used as
                the source side of every channel.
            mq: message-queue descriptor, passed through to ``_get_channel``.
            max_message_size: size limit used when splitting objects and
                batching table entries.
            conf: optional backend configuration, passed through to
                ``_get_channel``.
        """
        self._session_id = session_id
        self._party = party
        self._mq = mq
        self._max_message_size = max_message_size
        self._conf = conf

        # topic_key -> topic pair returned by _maybe_create_topic_and_replication
        self._topic_map = {}
        # topic_key -> channel returned by _get_channel
        self._channels_map = {}
        # "<role>^<party_id>^<name>^<tag>^<get|remote>" -> dtype metadata dict
        self._name_dtype_map = {}
        # objects received while waiting for a different (name, tag)
        self._message_cache = {}

    def __getstate__(self):
        """Return ``None`` so that pickling carries no instance state.

        NOTE(review): presumably this lets the partition send/receive
        closures (which capture ``self``) be shipped to computing workers
        without serialising channels or caches -- confirm with the backends.
        """
        return None

    @property
    def session_id(self) -> str:
        """The session id given at construction time."""
        return self._session_id

    def destroy(self, parties):
        """Release backend resources for *parties* (subclass hook)."""
        raise NotImplementedError()

    def get(
        self, name: str, tag: str, parties: typing.List[Party], gc: GarbageCollectionABC
    ) -> typing.List:
        """Receive the value published under ``(name, tag)`` by each party.

        The dtype metadata is fetched first (or taken from the local cache)
        and decides how the payload is read: tables and split objects are
        pulled partition by partition through ``computing_session``; plain
        objects are read directly with ``_receive_obj``.

        Args:
            name: transfer variable name.
            tag: transfer tag.
            parties: parties to receive from.
            gc: collector on which received tables are registered for
                deletion.

        Returns:
            One entry per party, in the order of *parties*: a table for
            ``TABLE`` payloads, otherwise the unpickled object.
        """
        log_str = f"[federation.get](name={name}, tag={tag}, parties={parties})"
        LOGGER.debug(f"[{log_str}]start to get")

        _name_dtype_keys = [
            _SPLIT_.join([party.role, party.party_id, name, tag, "get"])
            for party in parties
        ]

        if _name_dtype_keys[0] not in self._name_dtype_map:
            party_topic_infos = self._get_party_topic_infos(parties, dtype=NAME_DTYPE_TAG)
            channel_infos = self._get_channels(party_topic_infos=party_topic_infos)
            dtype_tag = _SPLIT_.join([tag, NAME_DTYPE_TAG])

            received_dtypes = []
            for info in channel_infos:
                obj = self._receive_obj(info, name, tag=dtype_tag)
                received_dtypes.append(obj)
                LOGGER.debug(
                    f"[federation.get] _name_dtype_keys: {_name_dtype_keys}, dtype: {obj}"
                )

            # The metadata of the first party is recorded for every party.
            for k in _name_dtype_keys:
                if k not in self._name_dtype_map:
                    self._name_dtype_map[k] = received_dtypes[0]

        rtn_dtype = self._name_dtype_map[_name_dtype_keys[0]]
        dtype = rtn_dtype.get("dtype", None)
        partitions = rtn_dtype.get("partitions", None)

        rtn = []
        if dtype in (FederationDataType.TABLE, FederationDataType.SPLIT_OBJECT):
            party_topic_infos = self._get_party_topic_infos(parties, name, partitions=partitions)
            for i, (party, topic_infos) in enumerate(zip(parties, party_topic_infos)):
                receive_func = self._get_partition_receive_func(
                    name=name,
                    tag=tag,
                    src_party_id=self._party.party_id,
                    src_role=self._party.role,
                    dst_party_id=party.party_id,
                    dst_role=party.role,
                    topic_infos=topic_infos,
                    mq=self._mq,
                    conf=self._conf
                )

                # One placeholder element per partition; each partition is
                # then replaced by what receive_func pulls from its topic.
                table = computing_session.parallelize(range(partitions), partitions, include_key=False)
                table = table.mapPartitionsWithIndex(receive_func)

                # add gc
                gc.add_gc_action(tag, table, "__del__", {})

                LOGGER.debug(
                    f"[{log_str}]received table({i + 1}/{len(parties)}), party: {party} "
                )
                if dtype == FederationDataType.TABLE:
                    rtn.append(table)
                else:
                    # SPLIT_OBJECT: rows are (slice_index, bytes); rebuild
                    # the pickled payload in index order and unpickle it.
                    ordered = sorted(table.collect(), key=lambda kv: kv[0])
                    obj_bytes = b''.join(chunk for _, chunk in ordered)
                    rtn.append(p_loads(obj_bytes))
        else:
            party_topic_infos = self._get_party_topic_infos(parties, name)
            channel_infos = self._get_channels(party_topic_infos=party_topic_infos)
            for i, info in enumerate(channel_infos):
                obj = self._receive_obj(info, name, tag)
                LOGGER.debug(
                    f"[{log_str}]received obj({i + 1}/{len(parties)}), party: {parties[i]} "
                )
                rtn.append(obj)

        LOGGER.debug(f"[{log_str}]finish to get")
        return rtn

    def remote(
        self,
        v,
        name: str,
        tag: str,
        parties: typing.List[Party],
        gc: GarbageCollectionABC,
    ) -> typing.NoReturn:
        """Send *v* to every party under ``(name, tag)``.

        On the first call for a given key a dtype metadata message is sent
        ahead of the payload. Non-table values whose pickle exceeds
        ``max_message_size`` are turned into a one-partition table of
        ``(slice_index, bytes)`` rows and sent as a table.

        Args:
            v: a ``CTableABC`` table or any picklable object.
            name: transfer variable name.
            tag: transfer tag.
            parties: destination parties.
            gc: collector on which a sent table is registered for deletion.
        """
        log_str = f"[federation.remote](name={name}, tag={tag}, parties={parties})"

        _name_dtype_keys = [
            _SPLIT_.join([party.role, party.party_id, name, tag, "remote"])
            for party in parties
        ]

        # NOTE(review): splitting of oversized objects only happens inside
        # this branch, i.e. the first time this key is sent.
        if _name_dtype_keys[0] not in self._name_dtype_map:
            party_topic_infos = self._get_party_topic_infos(parties, dtype=NAME_DTYPE_TAG)
            channel_infos = self._get_channels(party_topic_infos=party_topic_infos)

            if isinstance(v, CTableABC):
                body = {"dtype": FederationDataType.TABLE, "partitions": v.partitions}
            else:
                v, num_slice = _get_splits(v, self._max_message_size)
                if num_slice > 1:
                    v = computing_session.parallelize(data=v, partition=1, include_key=True)
                    body = {"dtype": FederationDataType.SPLIT_OBJECT, "partitions": v.partitions}
                else:
                    body = {"dtype": FederationDataType.OBJECT}

            LOGGER.debug(
                f"[federation.remote] _name_dtype_keys: {_name_dtype_keys}, dtype: {body}"
            )
            self._send_obj(
                name=name,
                tag=_SPLIT_.join([tag, NAME_DTYPE_TAG]),
                data=p_dumps(body),
                channel_infos=channel_infos,
            )

            for k in _name_dtype_keys:
                if k not in self._name_dtype_map:
                    self._name_dtype_map[k] = body

        if isinstance(v, CTableABC):
            total_size = v.count()
            partitions = v.partitions
            LOGGER.debug(
                f"[{log_str}]start to remote table, total_size={total_size}, partitions={partitions}"
            )

            party_topic_infos = self._get_party_topic_infos(parties, name, partitions=partitions)
            # add gc
            gc.add_gc_action(tag, v, "__del__", {})

            send_func = self._get_partition_send_func(
                name=name,
                tag=tag,
                partitions=partitions,
                party_topic_infos=party_topic_infos,
                src_party_id=self._party.party_id,
                src_role=self._party.role,
                mq=self._mq,
                max_message_size=self._max_message_size,
                conf=self._conf
            )
            # noinspection PyProtectedMember
            v.mapPartitionsWithIndex(send_func)
        else:
            LOGGER.debug(f"[{log_str}]start to remote obj")
            party_topic_infos = self._get_party_topic_infos(parties, name)
            channel_infos = self._get_channels(party_topic_infos=party_topic_infos)
            self._send_obj(
                name=name, tag=tag, data=p_dumps(v), channel_infos=channel_infos
            )

        LOGGER.debug(f"[{log_str}]finish to remote")

    def _get_party_topic_infos(
        self, parties: typing.List[Party], name=None, partitions=None, dtype=None
    ) -> typing.List:
        """Return, for each party, the list produced by ``_get_or_create_topic``."""
        return [
            self._get_or_create_topic(party, name, partitions, dtype)
            for party in parties
        ]

    def _maybe_create_topic_and_replication(self, party, topic_suffix):
        """Create (if needed) and return the topic pair for *party* (subclass hook)."""
        # gen names
        raise NotImplementedError()

    def _get_or_create_topic(
        self, party: Party, name=None, partitions=None, dtype=None
    ) -> typing.List:
        """Resolve the topic(s) used to talk to *party*.

        The topic keys are chosen as follows:
        * ``dtype`` given: one key ``role^party_id^dtype^dtype``;
        * ``partitions`` given: one key per partition,
          ``role^party_id^name^<index>``;
        * only ``name`` given: ``role^party_id^name``;
        * otherwise: ``role^party_id``.

        Topic pairs are created through
        ``_maybe_create_topic_and_replication`` on first use and cached.

        Returns:
            A list of ``(topic_key, topic_pair)`` tuples.
        """
        prefix = [party.role, party.party_id]
        if dtype is not None:
            topic_keys = [_SPLIT_.join(prefix + [dtype, dtype])]
        elif partitions is not None:
            topic_keys = [
                _SPLIT_.join(prefix + [name, str(i)]) for i in range(partitions)
            ]
        elif name is not None:
            topic_keys = [_SPLIT_.join(prefix + [name])]
        else:
            topic_keys = [_SPLIT_.join(prefix)]

        topic_infos = []
        for topic_key in topic_keys:
            if topic_key not in self._topic_map:
                # Everything after role and party id becomes the suffix.
                topic_suffix = "-".join(topic_key.split(_SPLIT_)[2:])
                self._topic_map[topic_key] = self._maybe_create_topic_and_replication(
                    party, topic_suffix
                )
            topic_infos.append((topic_key, self._topic_map[topic_key]))

        return topic_infos

    def _get_channel(
        self, topic_pair, src_party_id, src_role, dst_party_id, dst_role, mq=None, conf: dict = None
    ):
        """Open a channel on *topic_pair* between source and destination (subclass hook)."""
        raise NotImplementedError()

    def _get_channels(self, party_topic_infos):
        """Return one cached channel per topic in *party_topic_infos*.

        Channels are created with the local party as source and stored in
        ``_channels_map`` under their topic key.
        """
        channel_infos = []
        for topic_infos in party_topic_infos:
            for topic_key, topic_pair in topic_infos:
                info = self._channels_map.get(topic_key)
                if info is None:
                    role, party_id = topic_key.split(_SPLIT_)[:2]
                    info = self._get_channel(
                        topic_pair=topic_pair,
                        src_party_id=self._party.party_id,
                        src_role=self._party.role,
                        dst_party_id=party_id,
                        dst_role=role,
                        mq=self._mq,
                        conf=self._conf
                    )
                    self._channels_map[topic_key] = info
                channel_infos.append(info)
        return channel_infos

    def _get_channels_index(self, index, party_topic_infos, src_party_id, src_role, mq=None, conf: dict = None):
        """Open, for every party, a fresh channel on its *index*-th topic.

        Unlike ``_get_channels`` nothing is read from or written to the
        instance caches; all connection parameters come from the arguments.
        """
        channel_infos = []
        for topic_infos in party_topic_infos:
            # select specified topic_info for a party
            topic_key, topic_pair = topic_infos[index]
            role, party_id = topic_key.split(_SPLIT_)[:2]
            channel_infos.append(
                self._get_channel(
                    topic_pair=topic_pair,
                    src_party_id=src_party_id,
                    src_role=src_role,
                    dst_party_id=party_id,
                    dst_role=role,
                    mq=mq,
                    conf=conf
                )
            )
        return channel_infos

    def _send_obj(self, name, tag, data, channel_infos):
        """Publish *data* as a ``text/plain`` message on every channel."""
        for info in channel_infos:
            properties = {
                "content_type": "text/plain",
                "app_id": info._dst_party_id,
                "message_id": name,
                "correlation_id": tag
            }
            LOGGER.debug(f"[federation._send_obj]properties:{properties}.")
            info.produce(body=data, properties=properties)

    def _send_kv(
        self, name, tag, data, channel_infos, partition_size, partitions, message_key
    ):
        """Publish one batch of table entries as ``application/json``.

        ``partition_size``, ``partitions`` and ``message_key`` travel as a
        JSON string in the ``headers`` property.
        """
        headers = json.dumps(
            {
                "partition_size": partition_size,
                "partitions": partitions,
                "message_key": message_key
            }
        )
        for info in channel_infos:
            properties = {
                "content_type": "application/json",
                "app_id": info._dst_party_id,
                "message_id": name,
                "correlation_id": tag,
                "headers": headers
            }
            LOGGER.debug(f"[federation._send_kv]info: {info}, properties: {properties}.")
            info.produce(body=data, properties=properties)

    def _get_partition_send_func(
        self,
        name,
        tag,
        partitions,
        party_topic_infos,
        src_party_id,
        src_role,
        mq,
        max_message_size,
        conf: dict,
    ):
        """Return a ``(index, kvs)`` callable bound to ``_partition_send``."""

        def _fn(index, kvs):
            return self._partition_send(
                index=index,
                kvs=kvs,
                name=name,
                tag=tag,
                partitions=partitions,
                party_topic_infos=party_topic_infos,
                src_party_id=src_party_id,
                src_role=src_role,
                mq=mq,
                max_message_size=max_message_size,
                conf=conf,
            )

        return _fn

    def _partition_send(
        self,
        index,
        kvs,
        name,
        tag,
        partitions,
        party_topic_infos,
        src_party_id,
        src_role,
        mq,
        max_message_size,
        conf: dict,
    ):
        """Send the entries of partition *index* to every party in batches.

        Entries are hex-encoded pickles accumulated in a ``Datastream``. A
        batch is flushed with ``partition_size=-1`` once adding the next
        entry would reach ``max_message_size``; the last batch carries the
        total entry count as ``partition_size``.

        Returns:
            ``[(index, 1)]``.
        """
        channel_infos = self._get_channels_index(
            index=index, party_topic_infos=party_topic_infos, src_party_id=src_party_id, src_role=src_role, mq=mq,
            conf=conf
        )

        datastream = Datastream()
        base_message_key = str(index)
        message_key_idx = 0
        count = 0

        for k, v in kvs:
            count += 1
            el = {"k": p_dumps(k).hex(), "v": p_dumps(v).hex()}
            # roughly calculate the size of package to avoid serialization ;)
            projected_size = (
                datastream.get_size() + sys.getsizeof(el["k"]) + sys.getsizeof(el["v"])
            )
            if projected_size >= max_message_size:
                LOGGER.debug(
                    f"[federation._partition_send]The size of message is: {datastream.get_size()}"
                )
                message_key_idx += 1
                # Intermediate keys use "_" while the final key below uses
                # _SPLIT_; the values are kept as they are on the wire.
                message_key = base_message_key + "_" + str(message_key_idx)
                self._send_kv(
                    name=name,
                    tag=tag,
                    data=datastream.get_data().encode(),
                    channel_infos=channel_infos,
                    partition_size=-1,
                    partitions=partitions,
                    message_key=message_key,
                )
                datastream.clear()
            datastream.append(el)

        # Final batch (possibly empty) announces the partition's entry count.
        message_key_idx += 1
        message_key = _SPLIT_.join([base_message_key, str(message_key_idx)])
        self._send_kv(
            name=name,
            tag=tag,
            data=datastream.get_data().encode(),
            channel_infos=channel_infos,
            partition_size=count,
            partitions=partitions,
            message_key=message_key,
        )

        return [(index, 1)]

    def _get_message_cache_key(self, name, tag, party_id, role):
        """Build the ``name^tag^party_id^role`` key used by ``_message_cache``."""
        return _SPLIT_.join([name, tag, str(party_id), role])

    def _get_consume_message(self, channel_info):
        """Iterate ``(id, properties, body)`` messages of a channel (subclass hook)."""
        raise NotImplementedError()

    def _consume_ack(self, channel_info, id):
        """Acknowledge the message identified by *id* (subclass hook)."""
        raise NotImplementedError()

    def _query_receive_topic(self, channel_info):
        """Return the channel to consume from; this base version returns its argument."""
        return channel_info

    def _receive_obj(self, channel_info, name, tag):
        """Receive the plain object published under ``(name, tag)``.

        Objects that arrive for a different ``(name, tag)`` are acknowledged
        and parked in ``_message_cache`` so that a later call can pick them
        up without consuming again.

        Returns:
            The unpickled object. ``None`` is returned implicitly if the
            consume iterator ends before the wanted message shows up.

        Raises:
            ValueError: a consumed message is not ``text/plain``.
        """
        party_id = channel_info._dst_party_id
        role = channel_info._dst_role

        wish_cache_key = self._get_message_cache_key(name, tag, party_id, role)
        if wish_cache_key in self._message_cache:
            return self._message_cache.pop(wish_cache_key)

        channel_info = self._query_receive_topic(channel_info)

        for msg_id, properties, body in self._get_consume_message(channel_info):
            LOGGER.debug(
                f"[federation._receive_obj] properties: {properties}"
            )
            if properties["message_id"] != name or properties["correlation_id"] != tag:
                # todo: fix this
                LOGGER.warning(
                    f"[federation._receive_obj] require {name}.{tag}, got {properties['message_id']}.{properties['correlation_id']}"
                )

            if properties["content_type"] != "text/plain":
                raise ValueError(
                    f"[federation._receive_obj] properties.content_type is {properties['content_type']}, but must be text/plain"
                )

            cache_key = self._get_message_cache_key(
                properties["message_id"], properties["correlation_id"], party_id, role
            )
            # NOTE: the body is unpickled as received; it is assumed to come
            # from a trusted federation peer.
            recv_obj = p_loads(body)
            self._consume_ack(channel_info, msg_id)
            LOGGER.debug(
                f"[federation._receive_obj] cache_key: {cache_key}, wish_cache_key: {wish_cache_key}"
            )
            if cache_key == wish_cache_key:
                channel_info.cancel()
                return recv_obj
            self._message_cache[cache_key] = recv_obj

    def _get_partition_receive_func(
        self, name, tag, src_party_id, src_role, dst_party_id, dst_role, topic_infos, mq, conf: dict
    ):
        """Return a ``(index, kvs)`` callable bound to ``_partition_receive``."""

        def _fn(index, kvs):
            return self._partition_receive(
                index=index,
                kvs=kvs,
                name=name,
                tag=tag,
                src_party_id=src_party_id,
                src_role=src_role,
                dst_party_id=dst_party_id,
                dst_role=dst_role,
                topic_infos=topic_infos,
                mq=mq,
                conf=conf,
            )

        return _fn

    def _partition_receive(
        self,
        index,
        kvs,
        name,
        tag,
        src_party_id,
        src_role,
        dst_party_id,
        dst_role,
        topic_infos,
        mq,
        conf: dict,
    ):
        """Receive every entry of partition *index* from its topic.

        Batches are consumed until the number of received entries equals the
        ``partition_size`` announced by the sender's final batch. Messages
        for another ``(name, tag)`` and batches whose ``message_key`` was
        already seen are acknowledged and skipped.

        Args:
            index: partition index; selects the entry of *topic_infos*.
            kvs: placeholder partition content; not read.

        Returns:
            List of ``(key, value)`` pairs of the partition.

        Raises:
            ValueError: a matching message is not ``application/json``.
            Exception: any error raised while consuming, unless the
                partition was already complete.
        """
        topic_pair = topic_infos[index][1]
        channel_info = self._get_channel(topic_pair=topic_pair,
                                         src_party_id=src_party_id,
                                         src_role=src_role,
                                         dst_party_id=dst_party_id,
                                         dst_role=dst_role,
                                         mq=mq,
                                         conf=conf)

        message_key_cache = set()
        count = 0
        partition_size = -1  # unknown until the final batch announces it
        all_data = []

        channel_info = self._query_receive_topic(channel_info)

        while True:
            try:
                for msg_id, properties, body in self._get_consume_message(channel_info):
                    LOGGER.debug(
                        f"[federation._partition_receive] properties: {properties}."
                    )
                    if properties["message_id"] != name or properties["correlation_id"] != tag:
                        # todo: fix this
                        self._consume_ack(channel_info, msg_id)
                        LOGGER.debug(
                            f"[federation._partition_receive]: require {name}.{tag}, got {properties['message_id']}.{properties['correlation_id']}"
                        )
                        continue

                    if properties["content_type"] != "application/json":
                        # Previously this exception was built but never
                        # raised, so such messages were silently ignored.
                        raise ValueError(
                            f"[federation._partition_receive]properties.content_type is {properties['content_type']}, but must be application/json"
                        )

                    header = json.loads(properties["headers"])
                    message_key = header["message_key"]
                    if message_key in message_key_cache:
                        LOGGER.debug(
                            f"[federation._partition_receive] message_key : {message_key} is duplicated"
                        )
                        self._consume_ack(channel_info, msg_id)
                        continue
                    message_key_cache.add(message_key)

                    if header["partition_size"] >= 0:
                        partition_size = header["partition_size"]

                    data = json.loads(body.decode())
                    # NOTE: entries are unpickled as received; they are
                    # assumed to come from a trusted federation peer.
                    data_iter = (
                        (p_loads(bytes.fromhex(el["k"])), p_loads(bytes.fromhex(el["v"])))
                        for el in data
                    )
                    count += len(data)
                    LOGGER.debug(f"[federation._partition_receive] count: {count}")
                    all_data.extend(data_iter)
                    self._consume_ack(channel_info, msg_id)

                    if count == partition_size:
                        channel_info.cancel()
                        return all_data
            except Exception as e:
                LOGGER.error(
                    f"[federation._partition_receive]catch exception {e}, while receiving {name}.{tag}"
                )
                # avoid hang on consume()
                if count == partition_size:
                    channel_info.cancel()
                    return all_data
                raise
|