_federation.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. from fate_arch.common import Party
  18. from fate_arch.common import file_utils
  19. from fate_arch.common.log import getLogger
  20. from fate_arch.federation._federation import FederationBase
  21. from fate_arch.federation.rabbitmq._mq_channel import MQChannel
  22. from fate_arch.federation.rabbitmq._rabbit_manager import RabbitManager
  23. LOGGER = getLogger()
  24. # default message max size in bytes = 1MB
  25. DEFAULT_MESSAGE_MAX_SIZE = 1048576
  26. class MQ(object):
  27. def __init__(self, host, port, union_name, policy_id, route_table):
  28. self.host = host
  29. self.port = port
  30. self.union_name = union_name
  31. self.policy_id = policy_id
  32. self.route_table = route_table
  33. def __str__(self):
  34. return (
  35. f"MQ(host={self.host}, port={self.port}, union_name={self.union_name}, "
  36. f"policy_id={self.policy_id}, route_table={self.route_table})"
  37. )
  38. def __repr__(self):
  39. return self.__str__()
  40. class _TopicPair(object):
  41. def __init__(self, tenant=None, namespace=None, vhost=None, send=None, receive=None):
  42. self.tenant = tenant
  43. self.namespace = namespace
  44. self.vhost = vhost
  45. self.send = send
  46. self.receive = receive
  47. class Federation(FederationBase):
  48. @staticmethod
  49. def from_conf(
  50. federation_session_id: str,
  51. party: Party,
  52. runtime_conf: dict,
  53. **kwargs
  54. ):
  55. rabbitmq_config = kwargs["rabbitmq_config"]
  56. LOGGER.debug(f"rabbitmq_config: {rabbitmq_config}")
  57. host = rabbitmq_config.get("host")
  58. port = rabbitmq_config.get("port")
  59. mng_port = rabbitmq_config.get("mng_port")
  60. base_user = rabbitmq_config.get("user")
  61. base_password = rabbitmq_config.get("password")
  62. mode = rabbitmq_config.get("mode", "replication")
  63. # max_message_size;
  64. max_message_size = int(rabbitmq_config.get("max_message_size", DEFAULT_MESSAGE_MAX_SIZE))
  65. union_name = federation_session_id
  66. policy_id = federation_session_id
  67. rabbitmq_run = runtime_conf.get("job_parameters", {}).get("rabbitmq_run", {})
  68. LOGGER.debug(f"rabbitmq_run: {rabbitmq_run}")
  69. max_message_size = int(rabbitmq_run.get(
  70. "max_message_size", max_message_size))
  71. LOGGER.debug(f"set max message size to {max_message_size} Bytes")
  72. rabbit_manager = RabbitManager(
  73. base_user, base_password, f"{host}:{mng_port}", rabbitmq_run
  74. )
  75. rabbit_manager.create_user(union_name, policy_id)
  76. route_table_path = rabbitmq_config.get("route_table")
  77. if route_table_path is None:
  78. route_table_path = "conf/rabbitmq_route_table.yaml"
  79. route_table = file_utils.load_yaml_conf(conf_path=route_table_path)
  80. mq = MQ(host, port, union_name, policy_id, route_table)
  81. conf = rabbit_manager.runtime_config.get(
  82. "connection", {}
  83. )
  84. return Federation(
  85. federation_session_id, party, mq, rabbit_manager, max_message_size, conf, mode
  86. )
  87. def __init__(self, session_id, party: Party, mq: MQ, rabbit_manager: RabbitManager, max_message_size, conf, mode):
  88. super().__init__(session_id=session_id, party=party, mq=mq, max_message_size=max_message_size, conf=conf)
  89. self._rabbit_manager = rabbit_manager
  90. self._vhost_set = set()
  91. self._mode = mode
  92. def __getstate__(self):
  93. pass
  94. def destroy(self, parties):
  95. LOGGER.debug("[rabbitmq.cleanup]start to cleanup...")
  96. for party in parties:
  97. if self._party == party:
  98. continue
  99. vhost = self._get_vhost(party)
  100. LOGGER.debug(f"[rabbitmq.cleanup]start to cleanup vhost {vhost}...")
  101. self._rabbit_manager.clean(vhost)
  102. LOGGER.debug(f"[rabbitmq.cleanup]cleanup vhost {vhost} done")
  103. if self._mq.union_name:
  104. LOGGER.debug(f"[rabbitmq.cleanup]clean user {self._mq.union_name}.")
  105. self._rabbit_manager.delete_user(user=self._mq.union_name)
  106. def _get_vhost(self, party):
  107. low, high = (
  108. (self._party, party) if self._party < party else (party, self._party)
  109. )
  110. vhost = (
  111. f"{self._session_id}-{low.role}-{low.party_id}-{high.role}-{high.party_id}"
  112. )
  113. return vhost
  114. def _maybe_create_topic_and_replication(self, party, topic_suffix):
  115. if self._mode == "replication":
  116. return self._create_topic_by_replication_mode(party, topic_suffix)
  117. if self._mode == "client":
  118. return self._create_topic_by_client_mode(party, topic_suffix)
  119. raise ValueError("mode={self._mode} is not support!")
  120. def _create_topic_by_client_mode(self, party, topic_suffix):
  121. # gen names
  122. vhost_name = self._get_vhost(party)
  123. send_queue_name = f"{self._session_id}-{self._party.role}-{self._party.party_id}-{party.role}-{party.party_id}-{topic_suffix}"
  124. receive_queue_name = f"{self._session_id}-{party.role}-{party.party_id}-{self._party.role}-{self._party.party_id}-{topic_suffix}"
  125. topic_pair = _TopicPair(
  126. namespace=self._session_id,
  127. vhost=vhost_name,
  128. send=send_queue_name,
  129. receive=receive_queue_name
  130. )
  131. # initial vhost
  132. if topic_pair.vhost not in self._vhost_set:
  133. self._rabbit_manager.create_vhost(topic_pair.vhost)
  134. self._rabbit_manager.add_user_to_vhost(
  135. self._mq.union_name, topic_pair.vhost
  136. )
  137. self._vhost_set.add(topic_pair.vhost)
  138. # initial send queue, the name is send-${vhost}
  139. self._rabbit_manager.create_queue(topic_pair.vhost, topic_pair.send)
  140. # initial receive queue, the name is receive-${vhost}
  141. self._rabbit_manager.create_queue(topic_pair.vhost, topic_pair.receive)
  142. return topic_pair
  143. def _create_topic_by_replication_mode(self, party, topic_suffix):
  144. # gen names
  145. vhost_name = self._get_vhost(party)
  146. send_queue_name = f"send-{self._session_id}-{self._party.role}-{self._party.party_id}-{party.role}-{party.party_id}-{topic_suffix}"
  147. receive_queue_name = f"receive-{self._session_id}-{party.role}-{party.party_id}-{self._party.role}-{self._party.party_id}-{topic_suffix}"
  148. topic_pair = _TopicPair(
  149. namespace=self._session_id,
  150. vhost=vhost_name,
  151. send=send_queue_name,
  152. receive=receive_queue_name
  153. )
  154. # initial vhost
  155. if topic_pair.vhost not in self._vhost_set:
  156. self._rabbit_manager.create_vhost(topic_pair.vhost)
  157. self._rabbit_manager.add_user_to_vhost(
  158. self._mq.union_name, topic_pair.vhost
  159. )
  160. self._vhost_set.add(topic_pair.vhost)
  161. # initial send queue, the name is send-${vhost}
  162. self._rabbit_manager.create_queue(topic_pair.vhost, topic_pair.send)
  163. # initial receive queue, the name is receive-${vhost}
  164. self._rabbit_manager.create_queue(
  165. topic_pair.vhost, topic_pair.receive
  166. )
  167. upstream_uri = self._upstream_uri(party_id=party.party_id)
  168. self._rabbit_manager.federate_queue(
  169. upstream_host=upstream_uri,
  170. vhost=topic_pair.vhost,
  171. send_queue_name=topic_pair.send,
  172. receive_queue_name=topic_pair.receive,
  173. )
  174. return topic_pair
  175. def _upstream_uri(self, party_id):
  176. host = self._mq.route_table.get(int(party_id)).get("host")
  177. port = self._mq.route_table.get(int(party_id)).get("port")
  178. upstream_uri = (
  179. f"amqp://{self._mq.union_name}:{self._mq.policy_id}@{host}:{port}"
  180. )
  181. return upstream_uri
  182. def _get_channel(
  183. self, topic_pair, src_party_id, src_role, dst_party_id, dst_role, mq=None, conf: dict = None):
  184. LOGGER.debug(f"rabbitmq federation _get_channel, src_party_id={src_party_id}, src_role={src_role},"
  185. f"dst_party_id={dst_party_id}, dst_role={dst_role}")
  186. return MQChannel(
  187. host=mq.host,
  188. port=mq.port,
  189. user=mq.union_name,
  190. password=mq.policy_id,
  191. namespace=topic_pair.namespace,
  192. vhost=topic_pair.vhost,
  193. send_queue_name=topic_pair.send,
  194. receive_queue_name=topic_pair.receive,
  195. src_party_id=src_party_id,
  196. src_role=src_role,
  197. dst_party_id=dst_party_id,
  198. dst_role=dst_role,
  199. extra_args=conf,
  200. )
  201. def _get_consume_message(self, channel_info):
  202. for method, properties, body in channel_info.consume():
  203. LOGGER.debug(
  204. f"[rabbitmq._get_consume_message] method: {method}, properties: {properties}"
  205. )
  206. properties = {
  207. "message_id": properties.message_id,
  208. "correlation_id": properties.correlation_id,
  209. "content_type": properties.content_type,
  210. "headers": json.dumps(properties.headers)
  211. }
  212. yield method.delivery_tag, properties, body
  213. def _consume_ack(self, channel_info, id):
  214. channel_info.ack(delivery_tag=id)