_mq_channel.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import pika
  18. from fate_arch.common import log
  19. from fate_arch.federation._nretry import nretry
  20. LOGGER = log.getLogger()
  21. class MQChannel(object):
  22. def __init__(self,
  23. host,
  24. port,
  25. user,
  26. password,
  27. namespace,
  28. vhost,
  29. send_queue_name,
  30. receive_queue_name,
  31. src_party_id,
  32. src_role,
  33. dst_party_id,
  34. dst_role,
  35. extra_args: dict):
  36. self._host = host
  37. self._port = port
  38. self._credentials = pika.PlainCredentials(user, password)
  39. self._namespace = namespace
  40. self._vhost = vhost
  41. self._send_queue_name = send_queue_name
  42. self._receive_queue_name = receive_queue_name
  43. self._src_party_id = src_party_id
  44. self._src_role = src_role
  45. self._dst_party_id = dst_party_id
  46. self._dst_role = dst_role
  47. self._conn = None
  48. self._channel = None
  49. self._extra_args = extra_args
  50. if "heartbeat" not in self._extra_args:
  51. self._extra_args["heartbeat"] = 3600
  52. def __str__(self):
  53. return (
  54. f"MQChannel(host={self._host}, port={self._port}, namespace={self._namespace}, "
  55. f"src_party_id={self._src_party_id}, src_role={self._src_role},"
  56. f"dst_party_id={self._dst_party_id}, dst_role={self._dst_role},"
  57. f"send_queue_name={self._send_queue_name}, receive_queue_name={self._receive_queue_name}),"
  58. )
  59. def __repr__(self):
  60. return self.__str__()
  61. @nretry
  62. def produce(self, body, properties: dict):
  63. self._get_channel()
  64. LOGGER.debug(f"send queue: {self._send_queue_name}")
  65. if "headers" in properties:
  66. headers = json.loads(properties["headers"])
  67. else:
  68. headers = {}
  69. properties = pika.BasicProperties(
  70. content_type=properties["content_type"],
  71. app_id=properties["app_id"],
  72. message_id=properties["message_id"],
  73. correlation_id=properties["correlation_id"],
  74. headers=headers,
  75. delivery_mode=1,
  76. )
  77. return self._channel.basic_publish(exchange='', routing_key=self._send_queue_name, body=body,
  78. properties=properties)
  79. @nretry
  80. def consume(self):
  81. self._get_channel()
  82. LOGGER.debug(f"receive queue: {self._receive_queue_name}")
  83. return self._channel.consume(queue=self._receive_queue_name)
  84. @nretry
  85. def ack(self, delivery_tag):
  86. self._get_channel()
  87. return self._channel.basic_ack(delivery_tag=delivery_tag)
  88. @nretry
  89. def cancel(self):
  90. self._get_channel()
  91. return self._channel.cancel()
  92. def _get_channel(self):
  93. if self._check_alive():
  94. return
  95. else:
  96. self._clear()
  97. if not self._conn:
  98. self._conn = pika.BlockingConnection(pika.ConnectionParameters(host=self._host, port=self._port,
  99. virtual_host=self._vhost,
  100. credentials=self._credentials,
  101. **self._extra_args))
  102. if not self._channel:
  103. self._channel = self._conn.channel()
  104. self._channel.confirm_delivery()
  105. def _clear(self):
  106. try:
  107. if self._conn and self._conn.is_open:
  108. self._conn.close()
  109. self._conn = None
  110. if self._channel and self._channel.is_open:
  111. self._channel.close()
  112. self._channel = None
  113. except Exception as e:
  114. LOGGER.exception(e)
  115. self._conn = None
  116. self._channel = None
  117. def _check_alive(self):
  118. return self._channel and self._channel.is_open and self._conn and self._conn.is_open