Merged
28 commits
e3b9787
Token manager tracks instance_id in token_to_socket
masenf Oct 24, 2025
76630a6
RedisTokenManager: keep local dicts globally updated via pub/sub
masenf Oct 24, 2025
e71eece
Implement lost+found for StateUpdate without websocket
masenf Oct 24, 2025
096f6ac
Implement `enumerate_tokens` for TokenManager
masenf Oct 24, 2025
223747d
Consolidate on `_get_token_owner`
masenf Oct 24, 2025
07a9093
fix test_connection_banner.py: expect SocketRecord JSON
masenf Oct 24, 2025
70f73fb
Use a single lock waiter
masenf Oct 25, 2025
6c650df
Redis Oplock implementation
masenf Oct 27, 2025
b107aee
add test_background_task.py::test_fast_yielding
masenf Oct 27, 2025
d8f040e
Merge remote-tracking branch 'origin/main' into masenf/redis_lost+found
masenf Oct 27, 2025
83c8fcf
Implement real redis-backed test cases for lost+found
masenf Oct 27, 2025
36496dd
add some polling for the emit mocks since L+F doesn't happen immediately
masenf Oct 28, 2025
e1ef249
Merge remote-tracking branch 'origin/masenf/redis_lost+found' into ma…
masenf Oct 29, 2025
e20cb81
Fix up unit tests for OPLOCK_ENABLED mode
masenf Oct 29, 2025
615b11f
Merge remote-tracking branch 'origin/main' into masenf/redis_oplock
masenf Oct 29, 2025
fa8b0f4
support py3.10
masenf Nov 2, 2025
fd5283c
Do not track contended leases in-process
masenf Nov 3, 2025
da1ddb5
Add real+mock test cases for StateManagerRedis
masenf Nov 3, 2025
06c20f7
update test_state to use mock_redis when real redis is not available
masenf Nov 3, 2025
d771e3e
Merge remote-tracking branch 'origin/main' into masenf/redis_oplock
masenf Nov 3, 2025
838926f
safe await cancelled task
masenf Nov 3, 2025
b058569
explicitly disable oplock for basic test_redis cases
masenf Nov 3, 2025
a0c08c7
py3.10 support: asyncio.TimeoutError != TimeoutError
masenf Nov 3, 2025
871c885
break out of forever tasks when event loop goes away
masenf Nov 3, 2025
6a4db93
generalize "forever" tasks to centralize exception handling/retry
masenf Nov 4, 2025
474d856
remove unused arg
masenf Nov 4, 2025
d487a95
less racy way test_ensure_task_limit_window_passed
masenf Nov 4, 2025
38ea405
rename REFLEX_STATE_MANAGER_REDIS_DEBUG to match the class name
masenf Nov 4, 2025
7 changes: 7 additions & 0 deletions .github/workflows/unit_tests.yml
@@ -62,6 +62,13 @@ jobs:
export PYTHONUNBUFFERED=1
export REFLEX_REDIS_URL=redis://localhost:6379
uv run pytest tests/units --cov --no-cov-on-fail --cov-report=
- name: Run unit tests w/ redis and OPLOCK_ENABLED
if: ${{ matrix.os == 'ubuntu-latest' }}
run: |
export PYTHONUNBUFFERED=1
export REFLEX_REDIS_URL=redis://localhost:6379
export REFLEX_OPLOCK_ENABLED=true
uv run pytest tests/units --cov --no-cov-on-fail --cov-report=
# Change to explicitly install v1 when reflex-hosting-cli is compatible with v2
- name: Run unit tests w/ pydantic v1
run: |
6 changes: 6 additions & 0 deletions reflex/environment.py
@@ -727,6 +727,12 @@ class EnvironmentVariables:
# How long to wait between automatic reload on frontend error to avoid reload loops.
REFLEX_AUTO_RELOAD_COOLDOWN_TIME_MS: EnvVar[int] = env_var(10_000)

# Whether to enable debug logging for the redis state manager.
REFLEX_STATE_MANAGER_REDIS_DEBUG: EnvVar[bool] = env_var(False)

# Whether to opportunistically hold the redis lock to allow fast in-memory access while uncontended.
REFLEX_OPLOCK_ENABLED: EnvVar[bool] = env_var(False)


environment = EnvironmentVariables()

634 changes: 581 additions & 53 deletions reflex/istate/manager/redis.py

Large diffs are not rendered by default.
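The oplock changes themselves are not rendered here, but the REFLEX_OPLOCK_ENABLED description in reflex/environment.py above summarizes the idea: opportunistically hold the Redis lock so state access stays in-memory while uncontended. A minimal sketch of that general pattern, assuming redis.asyncio; the key layout, TTL, and ":contended" flag are illustrative guesses, not the PR's actual schema:

import asyncio

from redis.asyncio import Redis


class OpportunisticLock:
    """Illustrative sketch only; NOT the PR's actual implementation.

    Keep holding a Redis lease across accesses until another instance
    signals that it wants the key.
    """

    def __init__(self, redis: Redis, key: str, instance_id: str, ttl_ms: int = 10_000):
        self.redis = redis
        self.key = key
        self.instance_id = instance_id
        self.ttl_ms = ttl_ms
        self._held = False

    async def acquire(self) -> None:
        if self._held:
            # Fast path: we already hold the lease; renew the TTL and let the
            # caller serve state from process memory.
            await self.redis.pexpire(self.key, self.ttl_ms)
            return
        # Slow path: flag contention so the current holder releases early,
        # then spin until the lease is free (NX = set only if absent).
        while not await self.redis.set(
            self.key, self.instance_id, nx=True, px=self.ttl_ms
        ):
            await self.redis.set(f"{self.key}:contended", "1", px=self.ttl_ms)
            await asyncio.sleep(0.01)
        self._held = True

    async def release(self) -> None:
        if await self.redis.exists(f"{self.key}:contended"):
            # Another instance is waiting: actually give up the lease.
            await self.redis.delete(f"{self.key}:contended", self.key)
            self._held = False
        # Otherwise keep the lease so the next acquire() hits the fast path.

Judging from the commit list, the real implementation also handles lease handoff and routes updates through the lost+found channel when the owning instance changes; this sketch shows only the uncontended fast-path trade-off.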

109 changes: 109 additions & 0 deletions reflex/utils/tasks.py
@@ -0,0 +1,109 @@
"""Helpers for managing asyncio tasks."""

import asyncio
import time
from collections.abc import Callable, Coroutine
from typing import Any

from reflex.utils import console


async def _run_forever(
coro_function: Callable[..., Coroutine],
*args: Any,
suppress_exceptions: list[type[BaseException]],
exception_delay: float,
exception_limit: int,
exception_limit_window: float,
**kwargs: Any,
):
"""Wrapper to continuously run a coroutine function, suppressing certain exceptions.

Args:
coro_function: The coroutine function to run.
*args: The arguments to pass to the coroutine function.
suppress_exceptions: The exceptions to suppress.
exception_delay: The delay between retries when an exception is suppressed.
exception_limit: The maximum number of suppressed exceptions within the limit window before raising.
exception_limit_window: The time window in seconds for counting suppressed exceptions.
**kwargs: The keyword arguments to pass to the coroutine function.
"""
last_regular_loop_start = 0
exception_count = 0

while True:
# Reset the exception count when the limit window has elapsed since the last non-exception loop started.
if last_regular_loop_start + exception_limit_window < time.monotonic():
exception_count = 0
if not exception_count:
last_regular_loop_start = time.monotonic()
try:
await coro_function(*args, **kwargs)
except (asyncio.CancelledError, RuntimeError):
raise
except Exception as e:
if any(isinstance(e, ex) for ex in suppress_exceptions):
exception_count += 1
if exception_count >= exception_limit:
console.error(
f"{coro_function.__name__}: task exceeded exception limit {exception_limit} within {exception_limit_window}s: {e}"
)
raise
console.error(f"{coro_function.__name__}: task error suppressed: {e}")
await asyncio.sleep(exception_delay)
continue
raise


def ensure_task(
owner: Any,
task_attribute: str,
coro_function: Callable[..., Coroutine],
*args: Any,
suppress_exceptions: list[type[BaseException]] | None = None,
exception_delay: float = 1.0,
exception_limit: int = 5,
exception_limit_window: float = 60.0,
**kwargs: Any,
) -> asyncio.Task:
"""Ensure that a task is running for the given coroutine function.

Note: if the task is already running, args and kwargs are ignored.

Args:
owner: The owner of the task.
task_attribute: The attribute name to store/retrieve the task from the owner object.
coro_function: The coroutine function to run as a task.
*args: The arguments to pass to the coroutine function.
suppress_exceptions: The exceptions to log and continue when running the coroutine.
exception_delay: The delay between retries when an exception is suppressed.
exception_limit: The maximum number of suppressed exceptions within the limit window before raising.
exception_limit_window: The time window in seconds for counting suppressed exceptions.
**kwargs: The keyword arguments to pass to the coroutine function.

Returns:
The asyncio task running the coroutine function.
"""
if suppress_exceptions is None:
suppress_exceptions = []
if RuntimeError in suppress_exceptions:
msg = "Cannot suppress RuntimeError exceptions which may be raised by asyncio machinery."
raise RuntimeError(msg)

task = getattr(owner, task_attribute, None)
if task is None or task.done():
asyncio.get_running_loop() # Ensure we're in an event loop.
task = asyncio.create_task(
_run_forever(
coro_function,
*args,
suppress_exceptions=suppress_exceptions,
exception_delay=exception_delay,
exception_limit=exception_limit,
exception_limit_window=exception_limit_window,
**kwargs,
),
name=f"reflex_ensure_task|{type(owner).__name__}.{task_attribute}={coro_function.__name__}|{time.time()}",
)
setattr(owner, task_attribute, task)
return task
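A hedged usage sketch of the helper above; the Watcher class and its poll coroutine are hypothetical stand-ins, not part of the PR. The task handle lives on an attribute of the owner, so repeated calls are idempotent while the task is alive:

import asyncio

from reflex.utils.tasks import ensure_task


class Watcher:
    """Hypothetical owner; the task handle is stored on _poll_task."""

    _poll_task: asyncio.Task | None = None

    async def poll(self) -> None:
        # One iteration of work; _run_forever re-invokes it in a loop.
        await asyncio.sleep(1.0)


async def main() -> None:
    watcher = Watcher()
    # First call creates the task and stores it on watcher._poll_task.
    task = ensure_task(
        watcher,
        "_poll_task",
        watcher.poll,
        suppress_exceptions=[ConnectionError],
    )
    # While the task is alive, subsequent calls return the same handle
    # (args/kwargs are ignored, per the docstring note above).
    assert ensure_task(watcher, "_poll_task", watcher.poll) is task
    task.cancel()


asyncio.run(main())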
60 changes: 20 additions & 40 deletions reflex/utils/token_manager.py
@@ -14,6 +14,7 @@
from reflex.istate.manager.redis import StateManagerRedis
from reflex.state import BaseState, StateUpdate
from reflex.utils import console, prerequisites
from reflex.utils.tasks import ensure_task

if TYPE_CHECKING:
from redis.asyncio import Redis
@@ -239,8 +240,13 @@ def _handle_socket_record_del(self, token: str) -> None:
) is not None and socket_record.instance_id != self.instance_id:
self.sid_to_token.pop(socket_record.sid, None)

async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
async def _subscribe_socket_record_updates(self) -> None:
"""Subscribe to Redis keyspace notifications for socket record updates."""
await StateManagerRedis(
state=BaseState, redis=self.redis
)._enable_keyspace_notifications()
redis_db = self.redis.get_connection_kwargs().get("db", 0)

async with self.redis.pubsub() as pubsub:
await pubsub.psubscribe(
f"__keyspace@{redis_db}__:{self._get_redis_key('*')}"
@@ -260,26 +266,14 @@ async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
elif event == "set":
await self._get_token_owner(token, refresh=True)

async def _socket_record_updates_forever(self) -> None:
"""Background task to monitor Redis keyspace notifications for socket record updates."""
await StateManagerRedis(
state=BaseState, redis=self.redis
)._enable_keyspace_notifications()
redis_db = self.redis.get_connection_kwargs().get("db", 0)
while True:
try:
await self._subscribe_socket_record_updates(redis_db)
except asyncio.CancelledError: # noqa: PERF203
break
except Exception as e:
console.error(f"RedisTokenManager socket record update task error: {e}")

def _ensure_socket_record_task(self) -> None:
"""Ensure the socket record updates subscriber task is running."""
if self._socket_record_task is None or self._socket_record_task.done():
self._socket_record_task = asyncio.create_task(
self._socket_record_updates_forever()
)
ensure_task(
owner=self,
task_attribute="_socket_record_task",
coro_function=self._subscribe_socket_record_updates,
suppress_exceptions=[Exception],
)

async def link_token_to_sid(self, token: str, sid: str) -> str | None:
"""Link a token to a session ID with Redis-based duplicate detection.
@@ -386,23 +380,6 @@ async def _subscribe_lost_and_found_updates(
record = LostAndFoundRecord(**json.loads(message["data"].decode()))
await emit_update(StateUpdate(**record.update), record.token)

async def _lost_and_found_updates_forever(
self,
emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
):
"""Background task to monitor Redis lost and found deltas.

Args:
emit_update: The function to emit state updates.
"""
while True:
try:
await self._subscribe_lost_and_found_updates(emit_update)
except asyncio.CancelledError: # noqa: PERF203
break
except Exception as e:
console.error(f"RedisTokenManager lost and found task error: {e}")

def ensure_lost_and_found_task(
self,
emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
@@ -412,10 +389,13 @@ def ensure_lost_and_found_task(
Args:
emit_update: The function to emit state updates.
"""
if self._lost_and_found_task is None or self._lost_and_found_task.done():
self._lost_and_found_task = asyncio.create_task(
self._lost_and_found_updates_forever(emit_update)
)
ensure_task(
owner=self,
task_attribute="_lost_and_found_task",
coro_function=self._subscribe_lost_and_found_updates,
suppress_exceptions=[Exception],
emit_update=emit_update,
)

async def _get_token_owner(self, token: str, refresh: bool = False) -> str | None:
"""Get the instance ID of the owner of a token.
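For context on the mechanism _subscribe_socket_record_updates relies on: Redis keyspace notifications are opt-in pub/sub messages fired when keys change. A standalone sketch independent of the PR's helpers; the socket_record:* prefix is illustrative, while the channel pattern and notify-keyspace-events flags are standard Redis:

import asyncio

from redis.asyncio import Redis


async def watch_keys(prefix: str = "socket_record:*") -> None:
    redis = Redis()
    # Notifications are off by default. K = keyspace channel, g = generic
    # commands (del/expire), $ = string commands (set), x = expired events.
    await redis.config_set("notify-keyspace-events", "Kg$x")
    db = redis.get_connection_kwargs().get("db", 0)
    async with redis.pubsub() as pubsub:
        await pubsub.psubscribe(f"__keyspace@{db}__:{prefix}")
        async for message in pubsub.listen():
            if message["type"] != "pmessage":
                continue
            # The channel name carries the key; the payload carries the event.
            key = message["channel"].decode().split(":", 1)[1]
            event = message["data"].decode()
            print(f"{event}: {key}")


# asyncio.run(watch_keys())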
35 changes: 35 additions & 0 deletions tests/integration/test_background_task.py
@@ -55,6 +55,11 @@ async def handle_event_yield_only(self):
yield State.increment()
await asyncio.sleep(0.005)

@rx.event(background=True)
async def fast_yielding(self):
for _ in range(1000):
yield State.increment()

@rx.event
def increment(self):
self.counter += 1
@@ -202,6 +207,11 @@ def index() -> rx.Component:
on_click=State.disconnect_reconnect_background,
id="disconnect-reconnect-background",
),
rx.button(
"Fast Yielding",
on_click=State.fast_yielding,
id="fast-yielding",
),
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)

@@ -451,3 +461,28 @@ def test_disconnect_reconnect(
)
# Final update should come through on the new websocket connection
AppHarness.expect(lambda: counter.text == "3", timeout=5)


def test_fast_yielding(
background_task: AppHarness,
driver: WebDriver,
token: str,
) -> None:
"""Test that fast yielding works as expected.

Args:
background_task: harness for BackgroundTask app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert background_task.app_instance is not None

# get a reference to the fast-yielding button
fast_yielding_button = driver.find_element(By.ID, "fast-yielding")

# get a reference to the counter
counter = driver.find_element(By.ID, "counter")
assert background_task._poll_for(lambda: counter.text == "0", timeout=5)

fast_yielding_button.click()
assert background_task._poll_for(lambda: counter.text == "1000", timeout=50)
Empty file added tests/units/istate/__init__.py