diff --git a/nlds/rabbit/consumer.py b/nlds/rabbit/consumer.py
index e0d91c9d..f5efe8f2 100644
--- a/nlds/rabbit/consumer.py
+++ b/nlds/rabbit/consumer.py
@@ -21,6 +21,7 @@
 from json.decoder import JSONDecodeError
 from urllib3.exceptions import HTTPError
 import signal
+import threading as thr
 
 from pika.exceptions import StreamLostError, AMQPConnectionError
 from pika.channel import Channel
@@ -207,6 +208,9 @@ def __init__(self, queue: str = None, setup_logging_fl=False):
         # caught in the callback.
         self.print_tracebacks_fl = True
 
+        # List of active threads created by the consumption process
+        self.threads = []
+
         # Set up the logging and pass through constructor parameter
         self.setup_logging(enable=setup_logging_fl)
 
@@ -675,9 +679,16 @@ def _wrapped_callback(self, ch: Channel, method: Method, properties: Header,
         This should be performed on all consumers and should be left untouched
         in child implementations.
         """
+        ack_fl = None
+        # Begin polling so the keepalive daemon keeps the connection alive
+        self.keepalive.start_polling()
+
+        # Get and log thread info
+        thread_id = thr.get_ident()
+        self.log(f"Callback started in thread {thread_id}", self.RK_LOG_INFO)
+
         # Wrap callback with a try-except catching a selection of common
         # errors which can be caught without stopping consumption.
-        ack_fl = None
         try:
             ack_fl = self.callback(ch, method, properties, body, connection)
         except self.EXPECTED_EXCEPTIONS as original_error:
@@ -702,6 +713,33 @@ def _wrapped_callback(self, ch: Channel, method: Method, properties: Header,
             self.nack_message(ch, method.delivery_tag, connection)
         else:
             self.acknowledge_message(ch, method.delivery_tag, connection)
+
+        # Clear the consuming event so the keepalive stops polling the connection
+        self.keepalive.stop_polling()
+
+
+    def _start_callback_threaded(
+            self,
+            ch: Channel,
+            method: Method,
+            properties: Header,
+            body: bytes,
+            connection: Connection
+    ) -> None:
+        """Consumption method which starts the _wrapped_callback() in a new
+        thread so the connection can be kept alive in the main thread. Allows
+        for both long-running tasks and multithreading, though in the case of
+        the latter this may not be as efficient as just scaling out to another
+        consumer.
+
+        TODO: (2024-02-08) This currently doesn't work but I'll leave it here in
+        case we return to it further down the line.
+        """
+        t = thr.Thread(target=self._wrapped_callback, args=(ch, method,
+                                                            properties,
+                                                            body, connection))
+        t.start()
+        self.threads.append(t)
 
 
     def declare_bindings(self) -> None:
@@ -798,3 +836,10 @@ def run(self):
 
             self.channel.stop_consuming()
 
+            # Wait for all threads to complete
+            # TODO: what happens if we try to sigterm?
+            for t in self.threads:
+                t.join()
+
+            self.connection.close()
+
diff --git a/nlds/rabbit/keepalive.py b/nlds/rabbit/keepalive.py
new file mode 100644
index 00000000..954f3e2f
--- /dev/null
+++ b/nlds/rabbit/keepalive.py
@@ -0,0 +1,63 @@
+from uuid import uuid4
+import threading as thr
+from typing import List
+import time
+
+from pika.connection import Connection
+
+class KeepaliveDaemon():
+    """Class for orchestrating a connection keepalive daemon thread."""
+
+    def __init__(self, connection: Connection, heartbeat: int):
+        self.name = str(uuid4())  # thread names are strings, so cast the uuid
+        self.heartbeat = heartbeat
+        self.connection = connection
+        self.poll_event = thr.Event()
+        self.kill_event = thr.Event()
+        self.keepalive = None
+
+    @staticmethod
+    def get_thread_names() -> List[str]:
+        return [t.name for t in thr.enumerate()]
+
+    def start_polling(self):
+        self.poll_event.set()
+
+    def stop_polling(self):
+        self.poll_event.clear()
+
+    def start(self) -> None:
+        # Create a keepalive daemon thread named after the uuid of this object
+        if self.name not in self.get_thread_names():
+            # Start a daemon thread which processes data events in the
+            # background
+            self.poll_event.clear()
+            self.kill_event.clear()
+            self.keepalive = thr.Thread(
+                name=self.name, target=self.run,
+                # args=(self.connection, self.heartbeat,
+                #       self.poll_event, self.kill_event),
+                daemon=True,
+            )
+            self.keepalive.start()
+
+    def kill(self):
+        if self.name in self.get_thread_names():
+            self.kill_event.set()
+            # self.keepalive.join()
+
+    def run(self):
+        """Simple infinite loop which keeps the connection alive by calling
+        process_data_events() during consumption. This is intended to be run in
+        the background as a daemon thread, so it can be exited immediately at
+        main thread exit. Needs access to the active connection object and the
+        required heartbeat interval.
+        """
+        # While we have an open connection continue the process
+        while self.connection.is_open and not self.kill_event.is_set():
+            print(f"{self.name}: {self.get_thread_names()}, {self.poll_event.is_set()}")
+            # If we're actively consuming and the connection is blocked, then
+            # periodically call process_data_events to keep the connection open.
+            if self.poll_event.is_set():
+                self.connection.process_data_events()
+            time.sleep(max(self.heartbeat/2, 1))
\ No newline at end of file
diff --git a/nlds/rabbit/publisher.py b/nlds/rabbit/publisher.py
index 1a22976e..b6e09f3b 100644
--- a/nlds/rabbit/publisher.py
+++ b/nlds/rabbit/publisher.py
@@ -16,8 +16,11 @@
 from typing import Dict, List, Any
 import pathlib
 from collections.abc import Sequence
+import threading as thr
+import time
 
 import pika
+from pika.connection import Connection
 from pika.exceptions import AMQPConnectionError, UnroutableError, ChannelWrongStateError
 from retry import retry
 
@@ -35,6 +38,7 @@
     LOGGING_CONFIG_ENABLE,
     LOGGING_CONFIG_BACKUP_COUNT,
 )
+from .keepalive import KeepaliveDaemon
 from ..errors import RabbitRetryError
 
 logger = logging.getLogger("nlds.root")
@@ -195,6 +199,9 @@ def __init__(self, name="publisher", setup_logging_fl=False):
         self.connection = None
         self.channel = None
 
+        self.heartbeat = self.config.get("heartbeat") or 300
+        self.keepalive = None
+
         try:
             # Do some basic verification of the general retry delays.
             self.retry_delays = self.general_config[self.RETRY_DELAYS]
@@ -208,6 +215,7 @@
         if setup_logging_fl:
             self.setup_logging()
 
+
     @retry(
         RabbitRetryError,
         tries=-1,
@@ -222,7 +230,11 @@ def get_connection(self):
         # Get the username and password for rabbit
         rabbit_user = self.config["user"]
         rabbit_password = self.config["password"]
-        connection_heartbeat = self.config.get("heartbeat") or 300
+
+        # Kill any daemon threads before we make a new one for the new
+        # connection
+        if self.keepalive:
+            self.keepalive.kill()
 
         # Start the rabbitMQ connection
         connection = pika.BlockingConnection(
@@ -231,9 +243,11 @@ def get_connection(self):
                     credentials=pika.PlainCredentials(rabbit_user,
                                                       rabbit_password),
                     virtual_host=self.config["vhost"],
-                    heartbeat=connection_heartbeat,
+                    heartbeat=self.heartbeat,
                 )
             )
+            self.keepalive = KeepaliveDaemon(connection, self.heartbeat)
+            self.keepalive.start()
 
             # Create a new channel with basic qos
             channel = connection.channel()
@@ -252,6 +266,7 @@ def get_connection(self):
             logger.debug(f"{type(e).__name__}: {e}")
             raise RabbitRetryError(str(e), ampq_exception=e)
 
+
     def declare_bindings(self) -> None:
         """Go through list of exchanges from config file and declare each.
         Will also declare delayed exchanges for use in scheduled messaging if the
@@ -359,6 +374,7 @@ def publish_message(self,
         # the message will never be sent.
         # raise RabbitRetryError(str(e), ampq_exception=e)
 
+
     def get_retry_delay(self, retries: int):
         """Simple convenience function for getting the delay (in seconds) for
         an indexlist with a given number of retries. Works off of the member
@@ -369,9 +385,11 @@ def get_retry_delay(self, retries: int):
         retries = min(retries, len(self.retry_delays) - 1)
         return int(self.retry_delays[retries])
 
+
     def close_connection(self) -> None:
         self.connection.close()
 
+
 _default_logging_conf = {
     LOGGING_CONFIG_ENABLE: True,
     LOGGING_CONFIG_LEVEL: RK_LOG_INFO,
@@ -543,6 +561,7 @@ def setup_logging(
                 logger.warning(f"Failed to create log file for "
                                f"{log_file}: {str(e)}")
 
+
     def _log(self, log_message: str, log_level: str, target: str,
              **kwargs) -> None:
         """
@@ -581,6 +600,7 @@ def _log(self, log_message: str, log_level: str, target: str,
         message = self.create_log_message(log_message, target)
         self.publish_message(routing_key, message)
 
+
     def log(self, log_message: str, log_level: str, target: str = None,
             **kwargs) -> None:
         # Attempt to log to publisher's name
@@ -588,6 +608,7 @@ def log(self, log_message: str, log_level: str, target: str = None,
         target = self.name
         self._log(log_message, log_level, target, **kwargs)
 
+
     @classmethod
     def create_log_message(cls, message: str, target: str,
                            route: str = None) -> Dict[str, Any]:
diff --git a/nlds_processors/transferers/put_transfer.py b/nlds_processors/transferers/put_transfer.py
index 4d8eb68e..708ec471 100644
--- a/nlds_processors/transferers/put_transfer.py
+++ b/nlds_processors/transferers/put_transfer.py
@@ -92,7 +92,7 @@ def transfer(self, transaction_id: str, tenancy: str, access_key: str,
         # TODO: This begs the question of whether we need to store the
         # object-name at all
         path_details.object_name = path_details.original_path
-        try: 
+        try:
             result = client.fput_object(
                 bucket_name,
                 path_details.object_name,
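
For context, a minimal sketch of how the new KeepaliveDaemon is intended to be driven, mirroring what consumer.py does (start_polling() on callback entry, stop_polling() on exit). The host, queue name, and sleep below are hypothetical stand-ins for the real NLDS configuration and work; note that pika's BlockingConnection is not thread-safe, so having the daemon thread call process_data_events() while the main thread consumes is exactly the behaviour the TODOs above flag as unreliable.

import time

import pika

from nlds.rabbit.keepalive import KeepaliveDaemon

HEARTBEAT = 300  # matches the publisher's default heartbeat

# Hypothetical broker location; real values come from the NLDS config
connection = pika.BlockingConnection(
    pika.ConnectionParameters(host="localhost", heartbeat=HEARTBEAT)
)
channel = connection.channel()

# One daemon per connection; it only polls while the poll event is set
keepalive = KeepaliveDaemon(connection, HEARTBEAT)
keepalive.start()


def callback(ch, method, properties, body):
    # While polling is on, the daemon thread calls process_data_events()
    # so heartbeats are answered during this long-running callback
    keepalive.start_polling()
    try:
        time.sleep(HEARTBEAT * 2)  # stand-in for a long-running task
    finally:
        keepalive.stop_polling()
    ch.basic_ack(delivery_tag=method.delivery_tag)


channel.basic_consume(queue="nlds_q", on_message_callback=callback)
channel.start_consuming()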
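
On the _start_callback_threaded() TODO: because BlockingConnection is not thread-safe, pika's own thread-per-message example marshals the acknowledgement back to the connection's thread with add_callback_threadsafe() rather than acking from the worker. A condensed sketch of that pattern, with hypothetical helper names, in case it helps when this is revisited:

import functools
import threading


def do_work(connection, channel, delivery_tag, body):
    # ... the long-running task runs in this worker thread ...
    ack = functools.partial(channel.basic_ack, delivery_tag)
    # basic_ack must run on the thread that owns the connection, so
    # schedule it there instead of calling it directly from the worker
    connection.add_callback_threadsafe(ack)


def on_message(channel, method, properties, body, connection):
    t = threading.Thread(
        target=do_work,
        args=(connection, channel, method.delivery_tag, body),
    )
    t.start()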