Implement threaded consumers for connection keepalive during long-running consumption #90

Merged · 6 commits · Feb 9, 2024
Changes from all commits
47 changes: 46 additions & 1 deletion nlds/rabbit/consumer.py
@@ -21,6 +21,7 @@
from json.decoder import JSONDecodeError
from urllib3.exceptions import HTTPError
import signal
import threading as thr

from pika.exceptions import StreamLostError, AMQPConnectionError
from pika.channel import Channel
@@ -207,6 +208,9 @@ def __init__(self, queue: str = None, setup_logging_fl=False):
# caught in the callback.
self.print_tracebacks_fl = True

# List of active threads created by consumption process
self.threads = []

# Set up the logging and pass through constructor parameter
self.setup_logging(enable=setup_logging_fl)

@@ -675,9 +679,16 @@ def _wrapped_callback(self, ch: Channel, method: Method, properties: Header,
This should be performed on all consumers and should be left untouched
in child implementations.
"""
ack_fl = None
# Set the consuming event so the keepalive starts polling the connection
self.keepalive.start_polling()

# Get and log thread info
thread_id = thr.get_ident()
self.log(f"Callback started in thread {thread_id}", self.RK_LOG_INFO)

# Wrap callback with a try-except catching a selection of common
# errors which can be caught without stopping consumption.
ack_fl = None
try:
ack_fl = self.callback(ch, method, properties, body, connection)
except self.EXPECTED_EXCEPTIONS as original_error:
@@ -702,6 +713,33 @@ def _wrapped_callback(self, ch: Channel, method: Method, properties: Header,
self.nack_message(ch, method.delivery_tag, connection)
else:
self.acknowledge_message(ch, method.delivery_tag, connection)

# Clear the consuming event so the keepalive stops polling the connection
self.keepalive.stop_polling()


def _start_callback_threaded(
self,
ch: Channel,
method: Method,
properties: Header,
body: bytes,
connection: Connection
) -> None:
"""Consumption method which starts the _wrapped_callback() in a new
thread so the connection can be kept alive in the main thread. Allows
for both long-running tasks and multithreading, though in the case of
the latter this may not be as efficient as just scaling out to another
consumer.

TODO: (2024-02-08) This currently doesn't work but I'll leave it here in
case we return to it further down the line.
"""
t = thr.Thread(target=self._wrapped_callback,
               args=(ch, method, properties, body, connection))
t.start()
self.threads.append(t)
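# NOTE (a likely diagnosis, not stated in the PR): pika's BlockingConnection
# is not thread-safe, so an ack/nack issued from this worker thread races the
# main thread; pika's threaded example instead schedules the ack back onto
# the connection's thread via connection.add_callback_threadsafe().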


def declare_bindings(self) -> None:
@@ -798,3 +836,10 @@ def run(self):

self.channel.stop_consuming()

# Wait for all threads to complete
# TODO: what happens if we try to sigterm?
for t in self.threads:
t.join()

self.connection.close()
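
For comparison, pika's documented approach to the same heartbeat problem inverts this design: the callback hands the work to a worker thread, the main thread keeps servicing heartbeats inside start_consuming(), and the worker schedules its ack back onto the connection's thread. A minimal sketch of that pattern, assuming a plain BlockingConnection; the queue name and do_work() are placeholders, not part of this PR:

import functools
import threading

import pika

def handle_delivery(ch, method, properties, body):
    # Hand the message off to a worker thread so the main thread can keep
    # processing heartbeats inside start_consuming().
    t = threading.Thread(target=do_work, args=(ch, method.delivery_tag, body))
    t.start()

def do_work(ch, delivery_tag, body):
    ...  # long-running task goes here
    # BlockingConnection is not thread-safe, so schedule the ack onto the
    # connection's own thread instead of calling basic_ack directly.
    ch.connection.add_callback_threadsafe(
        functools.partial(ch.basic_ack, delivery_tag=delivery_tag)
    )

connection = pika.BlockingConnection(pika.ConnectionParameters(heartbeat=300))
channel = connection.channel()
channel.basic_consume(queue="example_q", on_message_callback=handle_delivery)
channel.start_consuming()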

63 changes: 63 additions & 0 deletions nlds/rabbit/keepalive.py
@@ -0,0 +1,63 @@
from uuid import uuid4
import threading as thr
from typing import List
import time

from pika.connection import Connection

class KeepaliveDaemon():
"""Class for orchestrating a connection keepalive daemon thread."""

def __init__(self, connection: Connection, heartbeat: int):
# Thread names are strings, so store the uuid as str for the
# membership checks against get_thread_names() below
self.name = str(uuid4())
self.heartbeat = heartbeat
self.connection = connection
self.poll_event = thr.Event()
self.kill_event = thr.Event()
self.keepalive = None

@staticmethod
def get_thread_names() -> List[str]:
return [t.name for t in thr.enumerate()]

def start_polling(self):
self.poll_event.set()

def stop_polling(self):
self.poll_event.clear()

def start(self) -> None:
# Create a keepalive daemon thread named after the uuid of this object
if self.name not in self.get_thread_names():
# Start a daemon thread which processes data events in the
# background
self.poll_event.clear()
self.kill_event.clear()
self.keepalive = thr.Thread(
name=self.name, target=self.run,
# args=(self.connection, self.heartbeat,
# self.poll_event, self.kill_event),
daemon=True,
)
self.keepalive.start()

def kill(self):
if self.name in self.get_thread_names():
self.kill_event.set()
# self.keepalive.join()

def run(self):
"""Simple infinite loop which keeps the connection alive by calling
process_data_events() during consumption. This is intended to be run in
the background as a daemon thread, can be exited immediately at main
thread exit. Needs to be passed the active connection object and the
required heartbeats.
"""
# While we have an open connection continue the process
while self.connection.is_open and not self.kill_event.is_set():
print(f"{self.name}: {self.get_thread_names()}, {self.poll_event.is_set()}")
# If we're actively consuming, periodically call
# process_data_events() to keep the connection open.
if self.poll_event.is_set():
self.connection.process_data_events()
time.sleep(max(self.heartbeat/2, 1))
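
To make the intended lifecycle concrete, here is a minimal usage sketch of the KeepaliveDaemon API above, assuming a reachable broker; long_running_task() is a placeholder:

import time

import pika
from nlds.rabbit.keepalive import KeepaliveDaemon

def long_running_task():
    time.sleep(900)  # stand-in for real work, e.g. a large transfer

connection = pika.BlockingConnection(pika.ConnectionParameters(heartbeat=300))
keepalive = KeepaliveDaemon(connection, heartbeat=300)
keepalive.start()            # daemon thread runs, but idles until polling

keepalive.start_polling()    # poll event set: process_data_events() fires
try:
    long_running_task()
finally:
    keepalive.stop_polling()  # poll event cleared again
keepalive.kill()             # kill event set so the run() loop exits
connection.close()

The sleep of max(heartbeat/2, 1) means the connection is polled at least twice per heartbeat window while consuming. One caveat, in line with the TODO in consumer.py: BlockingConnection is not thread-safe, so calling process_data_events() from the daemon thread while the main thread consumes is best-effort rather than guaranteed.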
25 changes: 23 additions & 2 deletions nlds/rabbit/publisher.py
@@ -16,8 +16,11 @@
from typing import Dict, List, Any
import pathlib
from collections.abc import Sequence
import threading as thr
import time

import pika
from pika.connection import Connection
from pika.exceptions import AMQPConnectionError, UnroutableError, ChannelWrongStateError
from retry import retry

@@ -35,6 +38,7 @@
LOGGING_CONFIG_ENABLE,
LOGGING_CONFIG_BACKUP_COUNT,
)
from .keepalive import KeepaliveDaemon
from ..errors import RabbitRetryError

logger = logging.getLogger("nlds.root")
@@ -195,6 +199,9 @@ def __init__(self, name="publisher", setup_logging_fl=False):

self.connection = None
self.channel = None
self.heartbeat = self.config.get("heartbeat") or 300
self.keepalive = None

try:
# Do some basic verification of the general retry delays.
self.retry_delays = self.general_config[self.RETRY_DELAYS]
@@ -208,6 +215,7 @@
if setup_logging_fl:
self.setup_logging()


@retry(
RabbitRetryError,
tries=-1,
@@ -222,7 +230,11 @@ def get_connection(self):
# Get the username and password for rabbit
rabbit_user = self.config["user"]
rabbit_password = self.config["password"]
connection_heartbeat = self.config.get("heartbeat") or 300

# Kill any daemon threads before we make a new one for the new
# connection
if self.keepalive:
self.keepalive.kill()

# Start the rabbitMQ connection
connection = pika.BlockingConnection(
@@ -231,9 +243,11 @@
credentials=pika.PlainCredentials(rabbit_user,
rabbit_password),
virtual_host=self.config["vhost"],
heartbeat=connection_heartbeat,
heartbeat=self.heartbeat,
)
)
self.keepalive = KeepaliveDaemon(connection, self.heartbeat)
self.keepalive.start()

# Create a new channel with basic qos
channel = connection.channel()
@@ -252,6 +266,7 @@ def get_connection(self):
logger.debug(f"{type(e).__name__}: {e}")
raise RabbitRetryError(str(e), ampq_exception=e)


def declare_bindings(self) -> None:
"""Go through list of exchanges from config file and declare each. Will
also declare delayed exchanges for use in scheduled messaging if the
@@ -359,6 +374,7 @@ def publish_message(self,
# the message will never be sent.
# raise RabbitRetryError(str(e), ampq_exception=e)


def get_retry_delay(self, retries: int):
"""Simple convenience function for getting the delay (in seconds) for an
indexlist with a given number of retries. Works off of the member
@@ -369,9 +385,11 @@
retries = min(retries, len(self.retry_delays) - 1)
return int(self.retry_delays[retries])


def close_connection(self) -> None:
self.connection.close()


_default_logging_conf = {
LOGGING_CONFIG_ENABLE: True,
LOGGING_CONFIG_LEVEL: RK_LOG_INFO,
@@ -543,6 +561,7 @@ def setup_logging(
logger.warning(f"Failed to create log file for "
f"{log_file}: {str(e)}")


def _log(self, log_message: str, log_level: str, target: str,
**kwargs) -> None:
"""
@@ -581,13 +600,15 @@ def _log(self, log_message: str, log_level: str, target: str,
message = self.create_log_message(log_message, target)
self.publish_message(routing_key, message)


def log(self, log_message: str, log_level: str, target: str = None,
**kwargs) -> None:
# Attempt to log to publisher's name
if not target:
target = self.name
self._log(log_message, log_level, target, **kwargs)


@classmethod
def create_log_message(cls, message: str, target: str,
route: str = None) -> Dict[str, Any]:
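
Since get_retry_delay() above clamps the retry count to the end of the configured list, any delivery retried past the last entry just keeps reusing the final (largest) delay. A small worked sketch, with a hypothetical retry_delays config:

# hypothetical retry_delays config, in seconds
retry_delays = [0, 30, 60, 3600]

def get_retry_delay(retries: int) -> int:
    retries = min(retries, len(retry_delays) - 1)
    return int(retry_delays[retries])

print(get_retry_delay(0))   # 0
print(get_retry_delay(2))   # 60
print(get_retry_delay(99))  # clamped to the last entry: 3600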
2 changes: 1 addition & 1 deletion nlds_processors/transferers/put_transfer.py
@@ -92,7 +92,7 @@ def transfer(self, transaction_id: str, tenancy: str, access_key: str,
# TODO: This begs the question of whether we need to store the
# object-name at all
path_details.object_name = path_details.original_path
try:
try:
result = client.fput_object(
bucket_name,
path_details.object_name,