Skip to content

Commit

Permalink
Fix task-related issue with context_aenter/context_aexit.
Browse files Browse the repository at this point in the history
  • Loading branch information
byllyfish committed Aug 2, 2023
1 parent 0e4d810 commit b350441
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 31 deletions.
4 changes: 2 additions & 2 deletions shellous/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,8 @@ class Runner:
"""Runner is an asynchronous context manager that runs a command.
```
async with cmd.run() as run:
# process run.stdin, run.stdout, run.stderr (if not None)
async with Runner(cmd) as run:
# process streams: run.stdin, run.stdout, run.stderr (if not None)
result = run.result()
```
"""
Expand Down
15 changes: 9 additions & 6 deletions shellous/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from .log import LOG_DETAIL, LOGGER, log_timer

_T = TypeVar("_T")
_Key = tuple[int, int]

# Stores current stack of context managers for immutable Command objects.
_CONTEXT_STACKS = contextvars.ContextVar[
Optional[dict[int, list[AsyncContextManager[_T]]]]
Optional[dict[_Key, list[AsyncContextManager[_T]]]]
]("shellous.context_stacks", default=None)

# True if OS is derived from BSD.
Expand Down Expand Up @@ -152,11 +153,12 @@ async def context_aenter(scope: int, ctxt_manager: AsyncContextManager[_T]) -> _
"Enter an async context manager."
ctxt_stacks = _CONTEXT_STACKS.get()
if ctxt_stacks is None:
ctxt_stacks = defaultdict[int, list[Any]](list)
ctxt_stacks = defaultdict[_Key, list[Any]](list)
_CONTEXT_STACKS.set(ctxt_stacks)

result = await ctxt_manager.__aenter__() # pylint: disable=unnecessary-dunder-call
stack = ctxt_stacks[scope]
task = asyncio.current_task()
stack = ctxt_stacks[scope, id(task)]
stack.append(ctxt_manager)

return result
Expand All @@ -173,10 +175,11 @@ async def context_aexit(
if ctxt_stacks is None:
raise RuntimeError("context var `shellous.context_stacks` is missing")

stack = ctxt_stacks[scope]
ctxt_manager = stack.pop()
task = asyncio.current_task()
stack = ctxt_stacks[scope, id(task)]
ctxt_manager = stack.pop() # FIXME: This can fail if exit task is different.
if not stack:
del ctxt_stacks[scope]
del ctxt_stacks[scope, id(task)]
if not ctxt_stacks:
_CONTEXT_STACKS.set(None)

Expand Down
23 changes: 0 additions & 23 deletions tests/test_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,3 @@ async def echo_workaround():

async def test_echo_workaround(echo_workaround):
assert echo_workaround.command.args[0] == "echo"


@contextlib.contextmanager
def _preserve_contextvars():
"Context manager that copies the `context_stack` context var."
old_context = contextvars.copy_context()
yield
new_context = contextvars.copy_context()
for var in old_context:
if var.name.startswith("shellous.") and var not in new_context:
var.set(old_context[var])


@pytest.fixture
async def echo_preserved():
"Another work-around explicitly copies the contextvars (a bit hacky)."
async with sh("echo") as run:
with _preserve_contextvars():
yield run


async def test_echo_preserved(echo_preserved):
assert echo_preserved.command.args[0] == "echo"
40 changes: 40 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"Unit tests for functions in util module."

import asyncio
import contextlib

import pytest

from shellous.util import (
EnvironmentDict,
close_fds,
coerce_env,
context_aenter,
context_aexit,
decode_bytes,
uninterrupted,
verify_dev_fd,
Expand Down Expand Up @@ -103,3 +106,40 @@ def test_environment_dict():
# EnvironmentDict is immutable.
with pytest.raises(TypeError):
d1["b"] = "2" # pyright: ignore[reportGeneralTypeIssues]


class _TestContextHelpers:
def __init__(self):
self.idx = 0
self.log = []

async def __aenter__(self):
self.idx += 1
return await context_aenter(id(self), self._ctxt(self.idx))

async def __aexit__(self, *args):
return await context_aexit(id(self), *args)

@contextlib.asynccontextmanager
async def _ctxt(self, idx):
self.log.append(f"enter {idx}")
yield self
self.log.append(f"exit {idx}")


async def test_context_helpers():
"""Test context manager helper functions are re-entrant even in overlapping
tasks."""
tc = _TestContextHelpers()

async def _task():
async with tc:
await asyncio.sleep(0.2)

async with tc:
task1 = asyncio.create_task(_task())
await asyncio.sleep(0.01)
await task1

# Note that child task outlives its parent task's context manager.
assert tc.log == ["enter 1", "enter 2", "exit 1", "exit 2"]

0 comments on commit b350441

Please sign in to comment.