Skip to content

Commit

Permalink
fix: faster module watcher tests (#3062)
Browse files Browse the repository at this point in the history
  • Loading branch information
mscolnick authored Dec 5, 2024
1 parent 9c8d11a commit 2de1e49
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
9 changes: 8 additions & 1 deletion marimo/_runtime/reload/module_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def _check_modules(
return stale_modules


MODULE_WATCHER_SLEEP_INTERVAL = 1.0

# For testing only - do not use in production
_TEST_SLEEP_INTERVAL: float | None = None


def watch_modules(
graph: dataflow.DirectedGraph,
reloader: ModuleReloader,
Expand All @@ -131,6 +137,7 @@ def watch_modules(
# work with a copy to avoid race conditions
# in CPython, dict.copy() is atomic
sys_modules = sys.modules.copy()
sleep_interval = _TEST_SLEEP_INTERVAL or MODULE_WATCHER_SLEEP_INTERVAL
while not should_exit.is_set():
# Collect the modules used by each cell
modules: dict[str, types.ModuleType] = {}
Expand Down Expand Up @@ -186,7 +193,7 @@ def watch_modules(
# Don't proceed until enqueue_run_stale_cells() has been processed,
# ie until stale cells have been rerun
run_is_processed.wait()
time.sleep(1)
time.sleep(sleep_interval)
# Update our snapshot of sys.modules
sys_modules = sys.modules.copy()

Expand Down
2 changes: 1 addition & 1 deletion tests/_runtime/reload/reload_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ def update_file(path: pathlib.Path, code: str) -> None:
(because that is stored in the file). The only reliable way
to achieve this seems to be to sleep.
"""
time.sleep(1.5)
time.sleep(1.05)
path.write_text(textwrap.dedent(code))
35 changes: 24 additions & 11 deletions tests/_runtime/reload/test_module_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,19 @@
from marimo._runtime.runtime import Kernel
from tests.conftest import ExecReqProvider

INTERVAL = 0.1


@pytest.fixture(autouse=True)
def _setup_test_sleep():
"""Automatically set up faster sleep interval for all tests in this module"""
import marimo._runtime.reload.module_watcher as mw

old_interval = mw._TEST_SLEEP_INTERVAL
mw._TEST_SLEEP_INTERVAL = INTERVAL
yield
mw._TEST_SLEEP_INTERVAL = old_interval


# these tests use random filenames for modules because they share
# the same sys.modules object, and each test needs fresh modules
Expand Down Expand Up @@ -57,7 +70,7 @@ def foo():
)

# wait for the watcher to pick up the change
await asyncio.sleep(2.5)
await asyncio.sleep(INTERVAL * 3)
assert k.graph.cells[er_1.cell_id].stale
assert k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
Expand Down Expand Up @@ -117,7 +130,7 @@ def foo():
)

# wait for the watcher to pick up the change
await asyncio.sleep(2.5)
await asyncio.sleep(INTERVAL * 3)
assert k.graph.cells[er_1.cell_id].stale
assert k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
Expand Down Expand Up @@ -172,7 +185,7 @@ def foo():
update_file(nested_module, "func = lambda : 2")

# wait for the watcher to pick up the change
await asyncio.sleep(2.5)
await asyncio.sleep(INTERVAL * 3)
assert k.graph.cells[er_1.cell_id].stale
assert k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
Expand Down Expand Up @@ -227,7 +240,7 @@ def foo():
update_file(nested_module, "func = lambda : 2")

# wait for the watcher to pick up the change
await asyncio.sleep(2.5)
await asyncio.sleep(INTERVAL * 3)
assert k.graph.cells[er_1.cell_id].stale
assert k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
Expand Down Expand Up @@ -334,7 +347,7 @@ async def test_reload_package(
update_file(nested_module, "func = lambda : 2")

# wait for the watcher to pick up the change
await asyncio.sleep(2.5)
await asyncio.sleep(INTERVAL * 3)
assert k.graph.cells[er_1.cell_id].stale
assert k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
Expand Down Expand Up @@ -390,7 +403,7 @@ def foo():
)

# wait for the watcher to pick up the change
await asyncio.sleep(2.5)
await asyncio.sleep(INTERVAL * 3)
assert k.graph.cells[er_1.cell_id].stale
assert k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
Expand Down Expand Up @@ -439,7 +452,7 @@ def foo():
)

# wait for the watcher to pick up the change
await asyncio.sleep(2.5)
await asyncio.sleep(INTERVAL * 3)
assert k.graph.cells[er_1.cell_id].stale
assert k.graph.cells[er_2.cell_id].stale
assert not k.graph.cells[er_3.cell_id].stale
Expand Down Expand Up @@ -502,10 +515,10 @@ def foo():
)

# wait for the watcher to pick up the change
elapsed = 0
while elapsed < 10:
await asyncio.sleep(1)
elapsed += 1
retries = 0
while retries < 10:
await asyncio.sleep(INTERVAL)
retries += 1
if k.graph.cells[er_1.cell_id].stale:
break

Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def teardown(self) -> None:
teardown_context()
self.stdout._watcher.stop()
self.stderr._watcher.stop()
if self.k.module_watcher is not None:
self.k.module_watcher.stop()
sys.modules["__main__"] = self._main


Expand Down

0 comments on commit 2de1e49

Please sign in to comment.