Skip to content

Adding handling of FAILING_OVER and FAILED_OVER events/push notifications #3716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: ps_hitless_upgrade_sync_redis
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 110 additions & 6 deletions redis/maintenance_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class MaintenanceState(enum.Enum):
NONE = "none"
MOVING = "moving"
MIGRATING = "migrating"
FAILING_OVER = "failing_over"


if TYPE_CHECKING:
Expand Down Expand Up @@ -261,6 +262,105 @@ def __hash__(self) -> int:
return hash((self.__class__, self.id))


class NodeFailingOverEvent(MaintenanceEvent):
    """
    Maintenance event signalling that a Redis cluster node has started
    failing over.

    It is received when a node begins a failover, either as part of cluster
    maintenance operations or while node failures are being handled.
    Identity is defined by the concrete class and ``id`` only; ``ttl`` takes
    no part in equality or hashing.

    Args:
        id (int): Unique identifier for this event
        ttl (int): Time-to-live in seconds for this notification
    """

    def __init__(self, id: int, ttl: int):
        # No additional state: both values go to the base class unchanged.
        super().__init__(id, ttl)

    def __repr__(self) -> str:
        """
        Build a debug representation that includes the identity, the timing
        fields and the current expiry status of the event.
        """
        deadline = self.creation_time + self.ttl
        # Clamp at zero so an already expired event never reports a
        # negative remaining time.
        seconds_left = max(0, deadline - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={deadline}, "
            f"remaining={seconds_left:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Compare two events: they are equal only when both are
        NodeFailingOverEvent objects of exactly the same type that carry
        the same id.
        """
        if isinstance(other, NodeFailingOverEvent):
            # The exact-type check keeps a subclass instance from comparing
            # equal to a plain NodeFailingOverEvent with the same id.
            return self.id == other.id and type(self) is type(other)
        return False

    def __hash__(self) -> int:
        """
        Hash on the same fields that equality uses, so instances can be
        stored in sets and used as dictionary keys.

        Returns:
            int: Hash value based on event type and id
        """
        identity = (self.__class__, self.id)
        return hash(identity)


class NodeFailedOverEvent(MaintenanceEvent):
    """
    Maintenance event signalling that a Redis cluster node has completed
    its failover.

    It is received once a node has finished failing over, either during
    cluster maintenance operations or after node failures were handled.
    The notification carries no ttl of its own, so every instance is
    created with ``DEFAULT_TTL``.

    Args:
        id (int): Unique identifier for this event
    """

    # TTL handed to the base class for every instance of this event.
    DEFAULT_TTL = 5

    def __init__(self, id: int):
        # Referenced through the class name (not ``self``) so the value
        # used here is always the one defined on NodeFailedOverEvent.
        super().__init__(id, NodeFailedOverEvent.DEFAULT_TTL)

    def __repr__(self) -> str:
        """
        Build a debug representation that includes the identity, the timing
        fields and the current expiry status of the event.
        """
        deadline = self.creation_time + self.ttl
        # Clamp at zero so an already expired event never reports a
        # negative remaining time.
        seconds_left = max(0, deadline - time.monotonic())
        return (
            f"{self.__class__.__name__}("
            f"id={self.id}, "
            f"ttl={self.ttl}, "
            f"creation_time={self.creation_time}, "
            f"expires_at={deadline}, "
            f"remaining={seconds_left:.1f}s, "
            f"expired={self.is_expired()}"
            f")"
        )

    def __eq__(self, other) -> bool:
        """
        Compare two events: they are equal only when both are
        NodeFailedOverEvent objects of exactly the same type that carry
        the same id.
        """
        if isinstance(other, NodeFailedOverEvent):
            # The exact-type check keeps a subclass instance from comparing
            # equal to a plain NodeFailedOverEvent with the same id.
            return self.id == other.id and type(self) is type(other)
        return False

    def __hash__(self) -> int:
        """
        Hash on the same fields that equality uses, so instances can be
        stored in sets and used as dictionary keys.

        Returns:
            int: Hash value based on event type and id
        """
        identity = (self.__class__, self.id)
        return hash(identity)


class MaintenanceEventsConfig:
"""
Configuration class for maintenance events handling behaviour. Events are received through
Expand Down Expand Up @@ -446,32 +546,36 @@ def __init__(

def handle_event(self, event: MaintenanceEvent):
if isinstance(event, NodeMigratingEvent):
return self.handle_migrating_event(event)
return self.handle_maintenance_start_event(MaintenanceState.MIGRATING)
elif isinstance(event, NodeMigratedEvent):
return self.handle_migration_completed_event(event)
return self.handle_maintenance_completed_event()
elif isinstance(event, NodeFailingOverEvent):
return self.handle_maintenance_start_event(MaintenanceState.FAILING_OVER)
elif isinstance(event, NodeFailedOverEvent):
return self.handle_maintenance_completed_event()
else:
logging.error(f"Unhandled event type: {event}")

def handle_migrating_event(self, notification: NodeMigratingEvent):
def handle_maintenance_start_event(self, maintenance_state: MaintenanceState):
if (
self.connection.maintenance_state == MaintenanceState.MOVING
or not self.config.is_relax_timeouts_enabled()
):
return
self.connection.maintenance_state = MaintenanceState.MIGRATING
self.connection.maintenance_state = maintenance_state
self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout)
# extend the timeout for all created connections
self.connection.update_current_socket_timeout(self.config.relax_timeout)

def handle_migration_completed_event(self, notification: "NodeMigratedEvent"):
def handle_maintenance_completed_event(self):
# Only reset timeouts if state is not MOVING and relax timeouts are enabled
if (
self.connection.maintenance_state == MaintenanceState.MOVING
or not self.config.is_relax_timeouts_enabled()
):
return
self.connection.reset_tmp_settings(reset_relax_timeout=True)
# Node migration completed - reset the connection
# Maintenance completed - reset the connection
# timeouts by providing -1 as the relax timeout
self.connection.update_current_socket_timeout(-1)
self.connection.maintenance_state = MaintenanceState.NONE
1 change: 0 additions & 1 deletion tests/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import redis
from redis.cache import CacheConfig
from redis.connection import CacheProxyConnection, Connection, to_bool
from redis.maintenance_events import MaintenanceState
from redis.utils import SSL_AVAILABLE

from .conftest import (
Expand Down
171 changes: 151 additions & 20 deletions tests/test_maintenance_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
NodeMovingEvent,
NodeMigratingEvent,
NodeMigratedEvent,
NodeFailingOverEvent,
NodeFailedOverEvent,
MaintenanceEventsConfig,
MaintenanceEventPoolHandler,
MaintenanceEventConnectionHandler,
MaintenanceState,
)


Expand Down Expand Up @@ -281,6 +284,84 @@ def test_equality_and_hash(self):
assert hash(event1) != hash(event3)


class TestNodeFailingOverEvent:
    """Unit tests for the NodeFailingOverEvent class."""

    def test_init(self):
        """The constructor stores id and ttl and records the creation time."""
        with patch("time.monotonic", return_value=1000):
            failing_over = NodeFailingOverEvent(id=1, ttl=5)
            assert (failing_over.id, failing_over.ttl) == (1, 5)
            assert failing_over.creation_time == 1000

    def test_repr(self):
        """The repr reports identity, ttl, remaining time and expiry flag."""
        with patch("time.monotonic", return_value=1000):
            failing_over = NodeFailingOverEvent(id=1, ttl=5)

        # Two seconds after creation, 3 of the 5 seconds should remain.
        with patch("time.monotonic", return_value=1002):
            rendered = repr(failing_over)
            expected_fragments = (
                "NodeFailingOverEvent",
                "id=1",
                "ttl=5",
                "remaining=3.0s",
                "expired=False",
            )
            for fragment in expected_fragments:
                assert fragment in rendered

    def test_equality_and_hash(self):
        """Equality and hashing depend on the id and ignore the ttl."""
        baseline = NodeFailingOverEvent(id=1, ttl=5)
        same_id_other_ttl = NodeFailingOverEvent(id=1, ttl=10)
        other_id = NodeFailingOverEvent(id=2, ttl=5)

        # Same id: equal and hash-compatible even though the ttl differs.
        assert baseline == same_id_other_ttl
        assert hash(baseline) == hash(same_id_other_ttl)

        # Different id: neither equal nor sharing a hash.
        assert baseline != other_id
        assert hash(baseline) != hash(other_id)


class TestNodeFailedOverEvent:
    """Unit tests for the NodeFailedOverEvent class."""

    def test_init(self):
        """The constructor stores the id, the default ttl and creation time."""
        with patch("time.monotonic", return_value=1000):
            failed_over = NodeFailedOverEvent(id=1)
            assert failed_over.id == 1
            assert failed_over.ttl == NodeFailedOverEvent.DEFAULT_TTL
            assert failed_over.creation_time == 1000

    def test_default_ttl(self):
        """DEFAULT_TTL is 5 and is applied to newly created events."""
        assert NodeFailedOverEvent.DEFAULT_TTL == 5
        failed_over = NodeFailedOverEvent(id=1)
        assert failed_over.ttl == 5

    def test_repr(self):
        """The repr reports identity, ttl, remaining time and expiry flag."""
        with patch("time.monotonic", return_value=1000):
            failed_over = NodeFailedOverEvent(id=1)

        # One second after creation, 4 of the 5 seconds should remain.
        with patch("time.monotonic", return_value=1001):
            rendered = repr(failed_over)
            expected_fragments = (
                "NodeFailedOverEvent",
                "id=1",
                "ttl=5",
                "remaining=4.0s",
                "expired=False",
            )
            for fragment in expected_fragments:
                assert fragment in rendered

    def test_equality_and_hash(self):
        """Equality and hashing depend solely on the id."""
        baseline = NodeFailedOverEvent(id=1)
        same_id = NodeFailedOverEvent(id=1)
        other_id = NodeFailedOverEvent(id=2)

        # Same id: equal and hash-compatible.
        assert baseline == same_id
        assert hash(baseline) == hash(same_id)

        # Different id: neither equal nor sharing a hash.
        assert baseline != other_id
        assert hash(baseline) != hash(other_id)


class TestMaintenanceEventsConfig:
"""Test the MaintenanceEventsConfig class."""

Expand Down Expand Up @@ -477,19 +558,41 @@ def test_handle_event_migrating(self):
"""Test handling of NodeMigratingEvent."""
event = NodeMigratingEvent(id=1, ttl=5)

with patch.object(self.handler, "handle_migrating_event") as mock_handle:
with patch.object(
self.handler, "handle_maintenance_start_event"
) as mock_handle:
self.handler.handle_event(event)
mock_handle.assert_called_once_with(event)
mock_handle.assert_called_once_with(MaintenanceState.MIGRATING)

def test_handle_event_migrated(self):
"""Test handling of NodeMigratedEvent."""
event = NodeMigratedEvent(id=1)

with patch.object(
self.handler, "handle_migration_completed_event"
self.handler, "handle_maintenance_completed_event"
) as mock_handle:
self.handler.handle_event(event)
mock_handle.assert_called_once_with(event)
mock_handle.assert_called_once_with()

def test_handle_event_failing_over(self):
"""Test handling of NodeFailingOverEvent."""
event = NodeFailingOverEvent(id=1, ttl=5)

with patch.object(
self.handler, "handle_maintenance_start_event"
) as mock_handle:
self.handler.handle_event(event)
mock_handle.assert_called_once_with(MaintenanceState.FAILING_OVER)

def test_handle_event_failed_over(self):
"""Test handling of NodeFailedOverEvent."""
event = NodeFailedOverEvent(id=1)

with patch.object(
self.handler, "handle_maintenance_completed_event"
) as mock_handle:
self.handler.handle_event(event)
mock_handle.assert_called_once_with()

def test_handle_event_unknown_type(self):
"""Test handling of unknown event type."""
Expand All @@ -500,43 +603,71 @@ def test_handle_event_unknown_type(self):
result = self.handler.handle_event(event)
assert result is None

def test_handle_migrating_event_disabled(self):
"""Test migrating event handling when relax timeouts are disabled."""
def test_handle_maintenance_start_event_disabled(self):
"""Test maintenance start event handling when relax timeouts are disabled."""
config = MaintenanceEventsConfig(relax_timeout=-1)
handler = MaintenanceEventConnectionHandler(self.mock_connection, config)
event = NodeMigratingEvent(id=1, ttl=5)

result = handler.handle_migrating_event(event)
result = handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)
assert result is None
self.mock_connection.update_current_socket_timeout.assert_not_called()

def test_handle_migrating_event_success(self):
"""Test successful migrating event handling."""
event = NodeMigratingEvent(id=1, ttl=5)
def test_handle_maintenance_start_event_moving_state(self):
"""Test maintenance start event handling when connection is in MOVING state."""
self.mock_connection.maintenance_state = MaintenanceState.MOVING

self.handler.handle_migrating_event(event)
result = self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)
assert result is None
self.mock_connection.update_current_socket_timeout.assert_not_called()

def test_handle_maintenance_start_event_migrating_success(self):
"""Test successful maintenance start event handling for migrating."""
self.mock_connection.maintenance_state = MaintenanceState.NONE

self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)

assert self.mock_connection.maintenance_state == MaintenanceState.MIGRATING
self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)
self.mock_connection.set_tmp_settings.assert_called_once_with(
tmp_relax_timeout=20
)

def test_handle_migration_completed_event_disabled(self):
"""Test migration completed event handling when relax timeouts are disabled."""
def test_handle_maintenance_start_event_failing_over_success(self):
"""Test successful maintenance start event handling for failing over."""
self.mock_connection.maintenance_state = MaintenanceState.NONE

self.handler.handle_maintenance_start_event(MaintenanceState.FAILING_OVER)

assert self.mock_connection.maintenance_state == MaintenanceState.FAILING_OVER
self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)
self.mock_connection.set_tmp_settings.assert_called_once_with(
tmp_relax_timeout=20
)

def test_handle_maintenance_completed_event_disabled(self):
"""Test maintenance completed event handling when relax timeouts are disabled."""
config = MaintenanceEventsConfig(relax_timeout=-1)
handler = MaintenanceEventConnectionHandler(self.mock_connection, config)
event = NodeMigratedEvent(id=1)

result = handler.handle_migration_completed_event(event)
result = handler.handle_maintenance_completed_event()
assert result is None
self.mock_connection.update_current_socket_timeout.assert_not_called()

def test_handle_migration_completed_event_success(self):
"""Test successful migration completed event handling."""
event = NodeMigratedEvent(id=1)
def test_handle_maintenance_completed_event_moving_state(self):
"""Test maintenance completed event handling when connection is in MOVING state."""
self.mock_connection.maintenance_state = MaintenanceState.MOVING

result = self.handler.handle_maintenance_completed_event()
assert result is None
self.mock_connection.update_current_socket_timeout.assert_not_called()

def test_handle_maintenance_completed_event_success(self):
"""Test successful maintenance completed event handling."""
self.mock_connection.maintenance_state = MaintenanceState.MIGRATING

self.handler.handle_migration_completed_event(event)
self.handler.handle_maintenance_completed_event()

assert self.mock_connection.maintenance_state == MaintenanceState.NONE
self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1)
self.mock_connection.reset_tmp_settings.assert_called_once_with(
reset_relax_timeout=True
Expand Down
Loading