123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import json
- import pika
- from fate_arch.common import log
- from fate_arch.federation._nretry import nretry
# Module-level logger from FATE's logging helper; MQChannel uses it for
# debug messages and for logging exceptions raised during cleanup.
LOGGER = log.getLogger()
class MQChannel(object):
    """A RabbitMQ channel bound to one send queue and one receive queue.

    Wraps a ``pika.BlockingConnection`` plus a single channel on it. Both are
    created lazily on first use and re-created by ``_get_channel`` whenever
    either one is found closed. The public operations are decorated with
    ``nretry`` (a project helper; presumably retries on failure -- see
    ``fate_arch.federation._nretry``).
    """

    def __init__(self,
                 host,
                 port,
                 user,
                 password,
                 namespace,
                 vhost,
                 send_queue_name,
                 receive_queue_name,
                 src_party_id,
                 src_role,
                 dst_party_id,
                 dst_role,
                 extra_args: dict):
        """Store connection settings; the connection itself is opened lazily.

        Args:
            host: Broker host passed to ``pika.ConnectionParameters``.
            port: Broker port passed to ``pika.ConnectionParameters``.
            user: User name for ``pika.PlainCredentials``.
            password: Password for ``pika.PlainCredentials``.
            namespace: Namespace label; only referenced by ``__str__`` here.
            vhost: Virtual host passed to ``pika.ConnectionParameters``.
            send_queue_name: Routing key used by ``produce``.
            receive_queue_name: Queue read by ``consume``.
            src_party_id, src_role, dst_party_id, dst_role: Identify the two
                endpoints; only referenced by ``__str__`` here.
            extra_args: Extra keyword arguments forwarded to
                ``pika.ConnectionParameters``. The dict is copied, so the
                caller's object is never modified; ``None`` is treated as
                empty. ``heartbeat`` defaults to 3600 when not supplied.
        """
        self._host = host
        self._port = port
        self._credentials = pika.PlainCredentials(user, password)
        self._namespace = namespace
        self._vhost = vhost
        self._send_queue_name = send_queue_name
        self._receive_queue_name = receive_queue_name
        self._src_party_id = src_party_id
        self._src_role = src_role
        self._dst_party_id = dst_party_id
        self._dst_role = dst_role
        self._conn = None
        self._channel = None
        # Copy so that injecting the default heartbeat does not leak back into
        # a dict the caller may share between several channels.
        self._extra_args = dict(extra_args) if extra_args else {}
        self._extra_args.setdefault("heartbeat", 3600)

    def __str__(self):
        return (
            f"MQChannel(host={self._host}, port={self._port}, namespace={self._namespace}, "
            f"src_party_id={self._src_party_id}, src_role={self._src_role}, "
            f"dst_party_id={self._dst_party_id}, dst_role={self._dst_role}, "
            f"send_queue_name={self._send_queue_name}, receive_queue_name={self._receive_queue_name})"
        )

    def __repr__(self):
        return self.__str__()

    @nretry
    def produce(self, body, properties: dict):
        """Publish ``body`` to the send queue through the default exchange.

        Args:
            body: Message payload handed to ``basic_publish``.
            properties: Must contain ``content_type``, ``app_id``,
                ``message_id`` and ``correlation_id``. May contain
                ``headers`` as a JSON-encoded string, which is decoded here.

        Returns:
            Whatever ``BlockingChannel.basic_publish`` returns.
        """
        self._get_channel()
        LOGGER.debug(f"send queue: {self._send_queue_name}")
        headers = json.loads(properties["headers"]) if "headers" in properties else {}
        basic_properties = pika.BasicProperties(
            content_type=properties["content_type"],
            app_id=properties["app_id"],
            message_id=properties["message_id"],
            correlation_id=properties["correlation_id"],
            headers=headers,
            # AMQP delivery mode 1 = non-persistent.
            delivery_mode=1,
        )
        # exchange='' is the AMQP default exchange, which routes by queue name.
        return self._channel.basic_publish(exchange='',
                                           routing_key=self._send_queue_name,
                                           body=body,
                                           properties=basic_properties)

    @nretry
    def consume(self):
        """Return ``BlockingChannel.consume`` for the receive queue."""
        self._get_channel()
        LOGGER.debug(f"receive queue: {self._receive_queue_name}")
        return self._channel.consume(queue=self._receive_queue_name)

    @nretry
    def ack(self, delivery_tag):
        """Acknowledge the message identified by ``delivery_tag``."""
        self._get_channel()
        return self._channel.basic_ack(delivery_tag=delivery_tag)

    @nretry
    def cancel(self):
        """Call ``BlockingChannel.cancel`` on the current channel."""
        self._get_channel()
        return self._channel.cancel()

    def _get_channel(self):
        """Make sure an open connection and channel are available.

        If either handle is missing or closed, both are released and a fresh
        connection and channel are created; publisher confirms are enabled on
        every new channel.
        """
        if self._check_alive():
            return
        # _clear() always resets both handles to None, so rebuild both.
        self._clear()
        self._conn = pika.BlockingConnection(
            pika.ConnectionParameters(host=self._host,
                                      port=self._port,
                                      virtual_host=self._vhost,
                                      credentials=self._credentials,
                                      **self._extra_args))
        self._channel = self._conn.channel()
        self._channel.confirm_delivery()

    def _clear(self):
        """Close the channel, then its connection, and reset both handles.

        Each resource is released independently: a failure while closing the
        channel is logged and does not prevent the connection from being
        closed. Both handles are ``None`` afterwards in every case.
        """
        try:
            if self._channel is not None and self._channel.is_open:
                self._channel.close()
        except Exception as e:
            LOGGER.exception(e)
        finally:
            self._channel = None
        try:
            if self._conn is not None and self._conn.is_open:
                self._conn.close()
        except Exception as e:
            LOGGER.exception(e)
        finally:
            self._conn = None

    def _check_alive(self):
        """Return True when both the channel and the connection are open."""
        return (self._channel is not None and self._channel.is_open
                and self._conn is not None and self._conn.is_open)
|