|
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import json
- from fate_arch.common import Party
- from fate_arch.common import file_utils
- from fate_arch.common.log import getLogger
- from fate_arch.federation._federation import FederationBase
- from fate_arch.federation.rabbitmq._mq_channel import MQChannel
- from fate_arch.federation.rabbitmq._rabbit_manager import RabbitManager
# Module-level logger used by every class in this file.
LOGGER = getLogger()

# Fallback upper bound for a single message, in bytes (1 MB); used when the
# configuration passed to Federation.from_conf gives no "max_message_size".
DEFAULT_MESSAGE_MAX_SIZE = 1024 * 1024
class MQ(object):
    """Plain holder for the RabbitMQ settings shared by one federation session.

    Attributes are stored exactly as given:

    * ``host`` / ``port`` -- broker address handed to ``MQChannel``.
    * ``union_name`` / ``policy_id`` -- passed as ``user`` / ``password``
      when ``Federation._get_channel`` opens a channel.
    * ``route_table`` -- mapping queried by ``Federation._upstream_uri`` with
      an integer party id to obtain that party's ``host`` and ``port``.
    """

    def __init__(self, host, port, union_name, policy_id, route_table):
        self.host, self.port = host, port
        self.union_name, self.policy_id = union_name, policy_id
        self.route_table = route_table

    def __str__(self):
        # Render every field as "name=value" in declaration order.
        field_names = ("host", "port", "union_name", "policy_id", "route_table")
        rendered = ", ".join(
            f"{name}={getattr(self, name)}" for name in field_names
        )
        return f"MQ({rendered})"

    def __repr__(self):
        # repr intentionally mirrors the human-readable form.
        return str(self)
class _TopicPair(object):
    """Names describing one send/receive queue pair between two parties.

    All fields default to ``None`` and are stored unchanged:
    ``tenant``, ``namespace``, ``vhost``, ``send`` (send queue name) and
    ``receive`` (receive queue name).
    """

    def __init__(self, tenant=None, namespace=None, vhost=None, send=None, receive=None):
        self.tenant, self.namespace, self.vhost = tenant, namespace, vhost
        self.send, self.receive = send, receive
class Federation(FederationBase):
    """RabbitMQ implementation of ``FederationBase``.

    For every remote party it lazily creates a vhost plus a send and a
    receive queue through ``RabbitManager``. In ``"replication"`` mode the
    queues are additionally federated with the remote party's broker; in
    ``"client"`` mode only the local vhost and queues are created.
    """

    @staticmethod
    def from_conf(
        federation_session_id: str,
        party: Party,
        runtime_conf: dict,
        **kwargs
    ):
        """Build a ``Federation`` from service-level and job-level configuration.

        Args:
            federation_session_id: Session id. It is also used as the
                RabbitMQ user name (``union_name``) and as ``policy_id``,
                which ``_get_channel`` passes as the channel password.
            party: The local party.
            runtime_conf: Job runtime configuration; the optional
                ``job_parameters.rabbitmq_run`` mapping may override
                ``max_message_size`` and is forwarded to ``RabbitManager``.
            **kwargs: Must contain ``rabbitmq_config`` with the keys read
                below (``host``, ``port``, ``mng_port``, ``user``,
                ``password`` and optionally ``mode``, ``max_message_size``,
                ``route_table``).

        Returns:
            A configured ``Federation`` instance.

        Raises:
            KeyError: If ``rabbitmq_config`` is missing from ``kwargs``.
        """
        rabbitmq_config = kwargs["rabbitmq_config"]
        # Never write the broker password to the log in clear text.
        loggable_config = {
            key: ("******" if key == "password" else value)
            for key, value in rabbitmq_config.items()
        }
        LOGGER.debug(f"rabbitmq_config: {loggable_config}")
        host = rabbitmq_config.get("host")
        port = rabbitmq_config.get("port")
        mng_port = rabbitmq_config.get("mng_port")
        base_user = rabbitmq_config.get("user")
        base_password = rabbitmq_config.get("password")
        mode = rabbitmq_config.get("mode", "replication")

        # Service-level message size limit, falling back to the module default.
        max_message_size = int(rabbitmq_config.get("max_message_size", DEFAULT_MESSAGE_MAX_SIZE))

        union_name = federation_session_id
        policy_id = federation_session_id

        rabbitmq_run = runtime_conf.get("job_parameters", {}).get("rabbitmq_run", {})
        LOGGER.debug(f"rabbitmq_run: {rabbitmq_run}")
        # A job-level setting takes precedence over the service-level one.
        max_message_size = int(rabbitmq_run.get(
            "max_message_size", max_message_size))
        LOGGER.debug(f"set max message size to {max_message_size} Bytes")

        rabbit_manager = RabbitManager(
            base_user, base_password, f"{host}:{mng_port}", rabbitmq_run
        )
        rabbit_manager.create_user(union_name, policy_id)

        route_table_path = rabbitmq_config.get("route_table")
        if route_table_path is None:
            route_table_path = "conf/rabbitmq_route_table.yaml"
        route_table = file_utils.load_yaml_conf(conf_path=route_table_path)

        mq = MQ(host, port, union_name, policy_id, route_table)
        conf = rabbit_manager.runtime_config.get(
            "connection", {}
        )
        return Federation(
            federation_session_id, party, mq, rabbit_manager, max_message_size, conf, mode
        )

    def __init__(self, session_id, party: Party, mq: MQ, rabbit_manager: RabbitManager, max_message_size, conf, mode):
        """Store the manager and mode; everything else is handled by the base class.

        Args:
            session_id: Federation session id.
            party: The local party.
            mq: RabbitMQ connection settings.
            rabbit_manager: Management client used to create/clean vhosts,
                users and queues.
            max_message_size: Maximum size of a single message, in bytes.
            conf: Extra connection arguments forwarded to the base class.
            mode: ``"replication"`` or ``"client"``.
        """
        super().__init__(session_id=session_id, party=party, mq=mq, max_message_size=max_message_size, conf=conf)
        self._rabbit_manager = rabbit_manager
        # vhosts already created by this instance, to avoid re-creating them.
        self._vhost_set = set()
        self._mode = mode

    def __getstate__(self):
        # Returns None, so no instance state is handed to pickle.
        # NOTE(review): presumably deliberate so that serialising an object
        # that references the federation does not drag connections along --
        # confirm before changing.
        return None

    def destroy(self, parties):
        """Clean the vhost shared with every remote party, then delete the session user.

        Args:
            parties: Iterable of parties; the local party is skipped.
        """
        LOGGER.debug("[rabbitmq.cleanup]start to cleanup...")
        for party in parties:
            if self._party == party:
                continue
            vhost = self._get_vhost(party)
            LOGGER.debug(f"[rabbitmq.cleanup]start to cleanup vhost {vhost}...")
            self._rabbit_manager.clean(vhost)
            LOGGER.debug(f"[rabbitmq.cleanup]cleanup vhost {vhost} done")
        if self._mq.union_name:
            LOGGER.debug(f"[rabbitmq.cleanup]clean user {self._mq.union_name}.")
            self._rabbit_manager.delete_user(user=self._mq.union_name)

    def _get_vhost(self, party):
        """Return the vhost name shared by the local party and ``party``.

        The two parties are ordered with ``<`` so the name is the same no
        matter which of the two sides computes it.
        """
        low, high = (
            (self._party, party) if self._party < party else (party, self._party)
        )
        return f"{self._session_id}-{low.role}-{low.party_id}-{high.role}-{high.party_id}"

    def _maybe_create_topic_and_replication(self, party, topic_suffix):
        """Create the topic pair for ``party`` according to the configured mode.

        Returns:
            The ``_TopicPair`` describing the vhost and queue names.

        Raises:
            ValueError: If the mode is neither ``"replication"`` nor ``"client"``.
        """
        if self._mode == "replication":
            return self._create_topic_by_replication_mode(party, topic_suffix)
        if self._mode == "client":
            return self._create_topic_by_client_mode(party, topic_suffix)
        # f-string so the offending mode actually appears in the message.
        raise ValueError(f"mode={self._mode} is not support!")

    def _new_topic_pair(self, party, topic_suffix, send_prefix="", receive_prefix=""):
        """Build the ``_TopicPair`` (vhost + queue names) for ``party``; no broker calls."""
        local, remote = self._party, party
        send_queue_name = (
            f"{send_prefix}{self._session_id}-{local.role}-{local.party_id}"
            f"-{remote.role}-{remote.party_id}-{topic_suffix}"
        )
        receive_queue_name = (
            f"{receive_prefix}{self._session_id}-{remote.role}-{remote.party_id}"
            f"-{local.role}-{local.party_id}-{topic_suffix}"
        )
        return _TopicPair(
            namespace=self._session_id,
            vhost=self._get_vhost(party),
            send=send_queue_name,
            receive=receive_queue_name,
        )

    def _ensure_vhost_and_queues(self, topic_pair):
        """Create the vhost once per instance, then the send and receive queues."""
        if topic_pair.vhost not in self._vhost_set:
            self._rabbit_manager.create_vhost(topic_pair.vhost)
            self._rabbit_manager.add_user_to_vhost(
                self._mq.union_name, topic_pair.vhost
            )
            self._vhost_set.add(topic_pair.vhost)

        # Queue names depend on the topic suffix, so they are created on every
        # call even when the vhost already exists.
        self._rabbit_manager.create_queue(topic_pair.vhost, topic_pair.send)
        self._rabbit_manager.create_queue(topic_pair.vhost, topic_pair.receive)

    def _create_topic_by_client_mode(self, party, topic_suffix):
        """Create vhost and queues locally, without queue federation.

        Queue names carry no ``send-``/``receive-`` prefix in this mode.
        """
        topic_pair = self._new_topic_pair(party, topic_suffix)
        self._ensure_vhost_and_queues(topic_pair)
        return topic_pair

    def _create_topic_by_replication_mode(self, party, topic_suffix):
        """Create vhost and queues, then federate them with the remote party's broker.

        Queue names are prefixed with ``send-`` and ``receive-`` in this mode.
        """
        topic_pair = self._new_topic_pair(
            party, topic_suffix, send_prefix="send-", receive_prefix="receive-"
        )
        self._ensure_vhost_and_queues(topic_pair)

        upstream_uri = self._upstream_uri(party_id=party.party_id)
        self._rabbit_manager.federate_queue(
            upstream_host=upstream_uri,
            vhost=topic_pair.vhost,
            send_queue_name=topic_pair.send,
            receive_queue_name=topic_pair.receive,
        )
        return topic_pair

    def _upstream_uri(self, party_id):
        """Return the ``amqp://`` URI of the broker serving ``party_id``.

        Raises:
            ValueError: If the route table has no entry for ``party_id``.
        """
        # Route table entries are looked up by integer party id.
        route = self._mq.route_table.get(int(party_id))
        if route is None:
            raise ValueError(
                f"party_id={party_id} is not found in the rabbitmq route table"
            )
        host = route.get("host")
        port = route.get("port")
        return f"amqp://{self._mq.union_name}:{self._mq.policy_id}@{host}:{port}"

    def _get_channel(
            self, topic_pair, src_party_id, src_role, dst_party_id, dst_role, mq=None, conf: dict = None):
        """Open an ``MQChannel`` for the given topic pair and party pair.

        Args:
            topic_pair: ``_TopicPair`` with namespace, vhost and queue names.
            src_party_id, src_role: Identity of the sending side.
            dst_party_id, dst_role: Identity of the receiving side.
            mq: Connection settings; falls back to this instance's own
                settings when omitted.
            conf: Extra arguments forwarded to ``MQChannel`` as ``extra_args``.
        """
        LOGGER.debug(f"rabbitmq federation _get_channel, src_party_id={src_party_id}, src_role={src_role},"
                     f"dst_party_id={dst_party_id}, dst_role={dst_role}")
        # The declared default (None) would otherwise fail on attribute access.
        if mq is None:
            mq = self._mq
        return MQChannel(
            host=mq.host,
            port=mq.port,
            user=mq.union_name,
            password=mq.policy_id,
            namespace=topic_pair.namespace,
            vhost=topic_pair.vhost,
            send_queue_name=topic_pair.send,
            receive_queue_name=topic_pair.receive,
            src_party_id=src_party_id,
            src_role=src_role,
            dst_party_id=dst_party_id,
            dst_role=dst_role,
            extra_args=conf,
        )

    def _get_consume_message(self, channel_info):
        """Yield ``(delivery_tag, properties, body)`` for each consumed message.

        ``properties`` is reduced to a plain dict with ``message_id``,
        ``correlation_id``, ``content_type`` and JSON-encoded ``headers``.
        """
        for method, properties, body in channel_info.consume():
            LOGGER.debug(
                f"[rabbitmq._get_consume_message] method: {method}, properties: {properties}"
            )
            properties = {
                "message_id": properties.message_id,
                "correlation_id": properties.correlation_id,
                "content_type": properties.content_type,
                "headers": json.dumps(properties.headers)
            }
            yield method.delivery_tag, properties, body

    def _consume_ack(self, channel_info, id):
        """Acknowledge the message whose delivery tag is ``id``."""
        # The parameter name shadows the builtin but is kept for caller compatibility.
        channel_info.ack(delivery_tag=id)
|