
Commit 781bd06

ENG-8212: Redis Oplock implementation (#5932)
* Token manager tracks instance_id in token_to_socket
* RedisTokenManager: keep local dicts globally updated via pub/sub
* Implement lost+found for StateUpdate without websocket
  When an update is emitted for a token, but the websocket for that token is on another instance of the app, post it to the lost+found channel, where other instances are listening for updates to send to their clients.
* Implement `enumerate_tokens` for TokenManager
  Sets the groundwork for broadcasting updates to all connected states.
* Consolidate on `_get_token_owner`
* Fix test_connection_banner.py: expect SocketRecord JSON
* Use a single lock waiter
  For more efficient and fair lock queueing, each StateManagerRedis uses a single task to monitor the keyspace for lock release/expire and then wakes up the next caller waiting in the queue (no fairness between separate processes, though). Lockers now wait on an `asyncio.Event` which is set by the redis pubsub waiter. If any locker waits longer than the lock_expiration, it simply tries to take the lock in case there was some mixup with the pub/sub, so the locker is never blocked forever.
* Redis Oplock implementation
  * When taking a lock from redis, hold it for 80% of the lock expiration timeout.
  * While the lock is held, other events processed against the instance use the cached in-memory copy of the state.
  * When the timeout expires or another process signals intent to access a locked state, flush the modified states to redis and release the lock.
  Set REFLEX_OPLOCK_ENABLED=1 to use this feature.
* Add test_background_task.py::test_fast_yielding
* Implement real redis-backed test cases for lost+found
  * Add some polling for the emit mocks since L+F doesn't happen immediately
* Fix up unit tests for OPLOCK_ENABLED mode
* Support py3.10
* Do not track contended leases in-process
  Always check redis for contended leases before granting a lease. It's a bit slower, but much more reliable and avoids racy lock_expiration timeouts when contention occurs before the lease is created or when the pubsub hasn't caught up to reality. Always start _lock_update_task in __post_init__ to avoid a race where the lease is granted, then contended, but the pubsub task hasn't started to catch the contention.
* Add real+mock test cases for StateManagerRedis
* Update test_state to use mock_redis when real redis is not available
* Safely await cancelled task
* Explicitly disable oplock for basic test_redis cases
* py3.10 support: asyncio.TimeoutError != TimeoutError
* Break out of forever tasks when the event loop goes away
  No point in continually spamming "no running event loop" to the console.
* Generalize "forever" tasks to centralize exception handling/retry
* Remove unused arg
* Make test_ensure_task_limit_window_passed less racy
* Rename REFLEX_STATE_MANAGER_REDIS_DEBUG to match the class name
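The oplock bullets above describe a lease that is held for 80% of the lock expiration and released early when another process signals contention. The following is a minimal toy sketch of that timing logic using an `asyncio.Event`, in the spirit of the single-lock-waiter design; `OplockLease` and `hold_fraction` are invented names for illustration, not the StateManagerRedis implementation.

```python
import asyncio


class OplockLease:
    """Toy model of an opportunistic lease: hold a lock for a fraction of
    its expiration, releasing early when another holder signals contention."""

    def __init__(self, lock_expiration: float, hold_fraction: float = 0.8):
        # Hold the lease for only part of the expiration so it is always
        # flushed/released before the underlying redis lock expires.
        self.hold_duration = lock_expiration * hold_fraction
        self.contended = asyncio.Event()

    async def hold(self) -> str:
        """Hold the lease until the hold window elapses or contention is signaled."""
        try:
            await asyncio.wait_for(self.contended.wait(), timeout=self.hold_duration)
            return "contended"  # another process wants the state: flush and release now
        except asyncio.TimeoutError:  # on py3.10, asyncio.TimeoutError != TimeoutError
            return "expired"  # hold window elapsed: flush and release


async def main() -> tuple[str, str]:
    # Uncontended lease: released after 80% of the lock expiration.
    quiet = OplockLease(lock_expiration=0.05)
    expired = await quiet.hold()

    # Contended lease: a simulated remote signal releases it early.
    busy = OplockLease(lock_expiration=10.0)
    asyncio.get_running_loop().call_later(0.01, busy.contended.set)
    contended = await busy.hold()
    return expired, contended


print(asyncio.run(main()))  # → ('expired', 'contended')
```

Waiting on an `Event` with a timeout captures both release paths in one await, which is why the real waiter can fall back to retrying the lock if a pub/sub notification is missed.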
1 parent 62264f1 commit 781bd06

File tree

13 files changed, +1943 −115 lines changed

.github/workflows/unit_tests.yml

Lines changed: 7 additions & 0 deletions
@@ -62,6 +62,13 @@ jobs:
           export PYTHONUNBUFFERED=1
           export REFLEX_REDIS_URL=redis://localhost:6379
           uv run pytest tests/units --cov --no-cov-on-fail --cov-report=
+      - name: Run unit tests w/ redis and OPLOCK_ENABLED
+        if: ${{ matrix.os == 'ubuntu-latest' }}
+        run: |
+          export PYTHONUNBUFFERED=1
+          export REFLEX_REDIS_URL=redis://localhost:6379
+          export REFLEX_OPLOCK_ENABLED=true
+          uv run pytest tests/units --cov --no-cov-on-fail --cov-report=
       # Change to explicitly install v1 when reflex-hosting-cli is compatible with v2
       - name: Run unit tests w/ pydantic v1
         run: |

reflex/environment.py

Lines changed: 6 additions & 0 deletions
@@ -727,6 +727,12 @@ class EnvironmentVariables:
     # How long to wait between automatic reload on frontend error to avoid reload loops.
     REFLEX_AUTO_RELOAD_COOLDOWN_TIME_MS: EnvVar[int] = env_var(10_000)
 
+    # Whether to enable debug logging for the redis state manager.
+    REFLEX_STATE_MANAGER_REDIS_DEBUG: EnvVar[bool] = env_var(False)
+
+    # Whether to opportunistically hold the redis lock to allow fast in-memory access while uncontended.
+    REFLEX_OPLOCK_ENABLED: EnvVar[bool] = env_var(False)
+
 
 environment = EnvironmentVariables()
reflex/istate/manager/redis.py

Lines changed: 581 additions & 53 deletions
Large diffs are not rendered by default.

reflex/utils/tasks.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
"""Helpers for managing asyncio tasks."""

import asyncio
import time
from collections.abc import Callable, Coroutine
from typing import Any

from reflex.utils import console


async def _run_forever(
    coro_function: Callable[..., Coroutine],
    *args: Any,
    suppress_exceptions: list[type[BaseException]],
    exception_delay: float,
    exception_limit: int,
    exception_limit_window: float,
    **kwargs: Any,
):
    """Wrapper to continuously run a coroutine function, suppressing certain exceptions.

    Args:
        coro_function: The coroutine function to run.
        *args: The arguments to pass to the coroutine function.
        suppress_exceptions: The exceptions to suppress.
        exception_delay: The delay between retries when an exception is suppressed.
        exception_limit: The maximum number of suppressed exceptions within the limit window before raising.
        exception_limit_window: The time window in seconds for counting suppressed exceptions.
        **kwargs: The keyword arguments to pass to the coroutine function.
    """
    last_regular_loop_start = 0
    exception_count = 0

    while True:
        # Reset the exception count when the limit window has elapsed since the last non-exception loop started.
        if last_regular_loop_start + exception_limit_window < time.monotonic():
            exception_count = 0
        if not exception_count:
            last_regular_loop_start = time.monotonic()
        try:
            await coro_function(*args, **kwargs)
        except (asyncio.CancelledError, RuntimeError):
            raise
        except Exception as e:
            if any(isinstance(e, ex) for ex in suppress_exceptions):
                exception_count += 1
                if exception_count >= exception_limit:
                    console.error(
                        f"{coro_function.__name__}: task exceeded exception limit {exception_limit} within {exception_limit_window}s: {e}"
                    )
                    raise
                console.error(f"{coro_function.__name__}: task error suppressed: {e}")
                await asyncio.sleep(exception_delay)
                continue
            raise


def ensure_task(
    owner: Any,
    task_attribute: str,
    coro_function: Callable[..., Coroutine],
    *args: Any,
    suppress_exceptions: list[type[BaseException]] | None = None,
    exception_delay: float = 1.0,
    exception_limit: int = 5,
    exception_limit_window: float = 60.0,
    **kwargs: Any,
) -> asyncio.Task:
    """Ensure that a task is running for the given coroutine function.

    Note: if the task is already running, args and kwargs are ignored.

    Args:
        owner: The owner of the task.
        task_attribute: The attribute name to store/retrieve the task from the owner object.
        coro_function: The coroutine function to run as a task.
        suppress_exceptions: The exceptions to log and continue when running the coroutine.
        exception_delay: The delay between retries when an exception is suppressed.
        exception_limit: The maximum number of suppressed exceptions within the limit window before raising.
        exception_limit_window: The time window in seconds for counting suppressed exceptions.
        *args: The arguments to pass to the coroutine function.
        **kwargs: The keyword arguments to pass to the coroutine function.

    Returns:
        The asyncio task running the coroutine function.
    """
    if suppress_exceptions is None:
        suppress_exceptions = []
    if RuntimeError in suppress_exceptions:
        msg = "Cannot suppress RuntimeError exceptions which may be raised by asyncio machinery."
        raise RuntimeError(msg)

    task = getattr(owner, task_attribute, None)
    if task is None or task.done():
        asyncio.get_running_loop()  # Ensure we're in an event loop.
        task = asyncio.create_task(
            _run_forever(
                coro_function,
                *args,
                suppress_exceptions=suppress_exceptions,
                exception_delay=exception_delay,
                exception_limit=exception_limit,
                exception_limit_window=exception_limit_window,
                **kwargs,
            ),
            name=f"reflex_ensure_task|{type(owner).__name__}.{task_attribute}={coro_function.__name__}|{time.time()}",
        )
        setattr(owner, task_attribute, task)
    return task
reflex/utils/token_manager.py

Lines changed: 20 additions & 40 deletions
@@ -14,6 +14,7 @@
 from reflex.istate.manager.redis import StateManagerRedis
 from reflex.state import BaseState, StateUpdate
 from reflex.utils import console, prerequisites
+from reflex.utils.tasks import ensure_task
 
 if TYPE_CHECKING:
     from redis.asyncio import Redis
@@ -239,8 +240,13 @@ def _handle_socket_record_del(self, token: str) -> None:
         ) is not None and socket_record.instance_id != self.instance_id:
             self.sid_to_token.pop(socket_record.sid, None)
 
-    async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
+    async def _subscribe_socket_record_updates(self) -> None:
         """Subscribe to Redis keyspace notifications for socket record updates."""
+        await StateManagerRedis(
+            state=BaseState, redis=self.redis
+        )._enable_keyspace_notifications()
+        redis_db = self.redis.get_connection_kwargs().get("db", 0)
+
         async with self.redis.pubsub() as pubsub:
             await pubsub.psubscribe(
                 f"__keyspace@{redis_db}__:{self._get_redis_key('*')}"
@@ -260,26 +266,14 @@ async def _subscribe_socket_record_updates(self, redis_db: int) -> None:
             elif event == "set":
                 await self._get_token_owner(token, refresh=True)
 
-    async def _socket_record_updates_forever(self) -> None:
-        """Background task to monitor Redis keyspace notifications for socket record updates."""
-        await StateManagerRedis(
-            state=BaseState, redis=self.redis
-        )._enable_keyspace_notifications()
-        redis_db = self.redis.get_connection_kwargs().get("db", 0)
-        while True:
-            try:
-                await self._subscribe_socket_record_updates(redis_db)
-            except asyncio.CancelledError:  # noqa: PERF203
-                break
-            except Exception as e:
-                console.error(f"RedisTokenManager socket record update task error: {e}")
-
     def _ensure_socket_record_task(self) -> None:
         """Ensure the socket record updates subscriber task is running."""
-        if self._socket_record_task is None or self._socket_record_task.done():
-            self._socket_record_task = asyncio.create_task(
-                self._socket_record_updates_forever()
-            )
+        ensure_task(
+            owner=self,
+            task_attribute="_socket_record_task",
+            coro_function=self._subscribe_socket_record_updates,
+            suppress_exceptions=[Exception],
+        )
 
     async def link_token_to_sid(self, token: str, sid: str) -> str | None:
         """Link a token to a session ID with Redis-based duplicate detection.
@@ -386,23 +380,6 @@ async def _subscribe_lost_and_found_updates(
         record = LostAndFoundRecord(**json.loads(message["data"].decode()))
         await emit_update(StateUpdate(**record.update), record.token)
 
-    async def _lost_and_found_updates_forever(
-        self,
-        emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
-    ):
-        """Background task to monitor Redis lost and found deltas.
-
-        Args:
-            emit_update: The function to emit state updates.
-        """
-        while True:
-            try:
-                await self._subscribe_lost_and_found_updates(emit_update)
-            except asyncio.CancelledError:  # noqa: PERF203
-                break
-            except Exception as e:
-                console.error(f"RedisTokenManager lost and found task error: {e}")
-
     def ensure_lost_and_found_task(
         self,
         emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]],
@@ -412,10 +389,13 @@ def ensure_lost_and_found_task(
         Args:
             emit_update: The function to emit state updates.
         """
-        if self._lost_and_found_task is None or self._lost_and_found_task.done():
-            self._lost_and_found_task = asyncio.create_task(
-                self._lost_and_found_updates_forever(emit_update)
-            )
+        ensure_task(
+            owner=self,
+            task_attribute="_lost_and_found_task",
+            coro_function=self._subscribe_lost_and_found_updates,
+            suppress_exceptions=[Exception],
+            emit_update=emit_update,
+        )
 
     async def _get_token_owner(self, token: str, refresh: bool = False) -> str | None:
         """Get the instance ID of the owner of a token.
tests/integration/test_background_task.py

Lines changed: 35 additions & 0 deletions
@@ -55,6 +55,11 @@ async def handle_event_yield_only(self):
             yield State.increment()
             await asyncio.sleep(0.005)
 
+        @rx.event(background=True)
+        async def fast_yielding(self):
+            for _ in range(1000):
+                yield State.increment()
+
         @rx.event
         def increment(self):
             self.counter += 1
@@ -202,6 +207,11 @@ def index() -> rx.Component:
             on_click=State.disconnect_reconnect_background,
             id="disconnect-reconnect-background",
         ),
+        rx.button(
+            "Fast Yielding",
+            on_click=State.fast_yielding,
+            id="fast-yielding",
+        ),
         rx.button("Reset", on_click=State.reset_counter, id="reset"),
     )
 
@@ -451,3 +461,28 @@ def test_disconnect_reconnect(
     )
     # Final update should come through on the new websocket connection
     AppHarness.expect(lambda: counter.text == "3", timeout=5)
+
+
+def test_fast_yielding(
+    background_task: AppHarness,
+    driver: WebDriver,
+    token: str,
+) -> None:
+    """Test that fast yielding works as expected.
+
+    Args:
+        background_task: harness for BackgroundTask app.
+        driver: WebDriver instance.
+        token: The token for the connected client.
+    """
+    assert background_task.app_instance is not None
+
+    # get a reference to all buttons
+    fast_yielding_button = driver.find_element(By.ID, "fast-yielding")
+
+    # get a reference to the counter
+    counter = driver.find_element(By.ID, "counter")
+    assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
+
+    fast_yielding_button.click()
+    assert background_task._poll_for(lambda: counter.text == "1000", timeout=50)
tests/units/istate/__init__.py

Whitespace-only changes.

tests/units/istate/manager/__init__.py

Whitespace-only changes.
