Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a trio repl #2972

Merged
merged 17 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions docs/source/reference-core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1903,6 +1903,66 @@ explicit and might be easier to reason about.
``contextvars``.


.. _interactive debugging:


Interactive debugging
---------------------

When you start an interactive Python session to debug any async program
(whether it's based on ``asyncio``, Trio, or something else), every await
expression needs to be inside an async function:

.. code-block:: console

$ python
Python 3.10.6
Type "help", "copyright", "credits" or "license" for more information.
>>> import trio
>>> await trio.sleep(1)
File "<stdin>", line 1
SyntaxError: 'await' outside function
>>> async def main():
... print("hello...")
... await trio.sleep(1)
... print("world!")
...
>>> trio.run(main)
hello...
world!

This can make it difficult to iterate quickly since you have to redefine the
whole function body whenever you make a tweak.

Trio provides a modified interactive console that lets you ``await`` at the top
level. You can access this console by running ``python -m trio``:

.. code-block:: console

$ python -m trio
Trio 0.21.0+dev, Python 3.10.6
Use "await" directly instead of "trio.run()".
Type "help", "copyright", "credits" or "license" for more information.
>>> import trio
>>> print("hello..."); await trio.sleep(1); print("world!")
hello...
world!

If you are an IPython user, you can use IPython's `autoawait
<https://ipython.readthedocs.io/en/stable/interactive/magics.html#magic-autoawait>`__
function. This can be enabled within the IPython shell by running the magic command
``%autoawait trio``. To have ``autoawait`` enabled whenever Trio installed, you can
add the following to your IPython startup files.
(e.g. ``~/.ipython/profile_default/startup/10-async.py``)

.. code-block::

try:
import trio
get_ipython().run_line_magic("autoawait", "trio")
except ImportError:
pass

Exceptions and warnings
-----------------------

Expand Down
16 changes: 16 additions & 0 deletions newsfragments/2972.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Added an interactive interpreter ``python -m trio``.

This makes it easier to try things and experiment with trio in the a Python repl.
Use the ``await`` keyword without needing to call ``trio.run()``

.. code-block:: console

$ python -m trio
Trio 0.21.0+dev, Python 3.10.6
Use "await" directly instead of "trio.run()".
Type "help", "copyright", "credits" or "license" for more information.
>>> import trio
>>> await trio.sleep(1); print("hi") # prints after one second
hi

See :ref:`interactive debugging` for further detail.
3 changes: 3 additions & 0 deletions src/trio/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from trio._repl import main

main(locals())
A5rocks marked this conversation as resolved.
Show resolved Hide resolved
96 changes: 96 additions & 0 deletions src/trio/_repl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

import ast
import contextlib
import inspect
import sys
import types
import warnings
from code import InteractiveConsole

import trio
import trio.lowlevel


class TrioInteractiveConsole(InteractiveConsole):
# code.InteractiveInterpreter defines locals as Mapping[str, Any]
# but when we pass this to FunctionType it expects a dict. So
# we make the type more specific on our subclass
locals: dict[str, object]

def __init__(self, repl_locals: dict[str, object] | None = None):
super().__init__(locals=repl_locals)
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT

def runcode(self, code: types.CodeType) -> None:
async def _runcode_in_trio() -> BaseException | None:
func = types.FunctionType(code, self.locals)
try:
coro = func()
except BaseException as e:
return e

if inspect.iscoroutine(coro):
try:
await coro
except BaseException as e:
return e
Copy link
Contributor

@A5rocks A5rocks Apr 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can some of this exception stuff be replaced with stuff from outcome? I don't see it providing much technical benefit but it would decrease some code duplication.


Looking at outcome's code, it looks like they remove one layer off the exception's traceback where this doesn't. That might be a good thing to do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might not be understanding how outcomes works, but I don't see how it would simplify this. Wouldn't it just end up calling capture method and then immediately unwrapping?

try:
  coro = outcome.capture(func)
  coro.unwrap()
except BaseException as e:
  return e

The traceback printing is consistent with the standard python repl and the python -m asyncio so, I'm leaning towards keeping it as is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well specifically I was thinking directly returning the Value/Error (and an if on whether to use capture or acapture? Which would disallow leaving out the await when calling an async function but that's already trio style. I'll make this a separate comment when I get back to looking at this). If the trace backs are the same then the only benefit would be a few less lines, which probably isn't worth the work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I came up with this. It's better in some ways, but at the same time bit strange.

    def runcode(self, code: types.CodeType) -> None:
        async def _runcode_in_trio():
            func = types.FunctionType(code, self.locals)
            result = outcome.capture(func)
            
            if isinstance(result, outcome.Error):
                return result
            
            coro = result.unwrap()
            if inspect.iscoroutine(coro):
                return await outcome.acapture(lambda: coro)
        
            return outcome.Value(coro)

        try:
            value = trio.from_thread.run(_runcode_in_trio).unwrap()
        except SystemExit:
            # If it is SystemExit quit the repl. Otherwise, print the
            # traceback.
            # There could be a SystemExit inside a BaseExceptionGroup. If
            # that happens, it probably isn't the user trying to quit the
            # repl, but an error in the code. So we print the exception
            # and stay in the repl.
            raise
        except BaseException:
            self.showtraceback()

I think it is awkward because there is a single step for most expressions, but any repl input that is an await needs the second step of evaluation.

Copy link
Contributor

@A5rocks A5rocks Apr 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I finally tried poking at this a bit, ultimately what I was thinking about is:

    def runcode(self, code: types.CodeType) -> None:
        async def _runcode_in_trio():
            func = types.FunctionType(code, self.locals)
            if inspect.iscoroutinefunction(func):
                return await outcome.acapture(func)
            else:
                return outcome.capture(func)

        try:
            value = trio.from_thread.run(_runcode_in_trio).unwrap()
        except SystemExit:
            # If it is SystemExit quit the repl. Otherwise, print the
            # traceback.
            # There could be a SystemExit inside a BaseExceptionGroup. If
            # that happens, it probably isn't the user trying to quit the
            # repl, but an error in the code. So we print the exception
            # and stay in the repl.
            raise
        except BaseException:
            self.showtraceback()

probably doesn't work cause I haven't tested exactly that, but inspect.iscoroutinefunction seems to do the right thing (doesn't trigger for trio.sleep(1) or async def x(): ..., but triggers for await trio.sleep(1))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, iscoroutinefunction is the bit I was missing. That is much better. Thanks for the nudge in that direction. I've got a few bits to double check but it looks like that works right.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally it might not be the best idea (since it can’t detect code that just returns a coroutine object directly). But in this case we know return isn’t valid.

return None

maybe_exc_or_excgroup = trio.from_thread.run(_runcode_in_trio)

if maybe_exc_or_excgroup is not None:
# maybe_exc_or_excgroup is an exception, or an exception group.
# If it is SystemExit quit the repl. Otherwise, print the
# traceback.
# There could be a SystemExit inside a BaseExceptionGroup. If
# that happens, it probably isn't the user trying to quit the
# repl, but an error in the code. So we print the exception
# and stay in the repl.
if isinstance(maybe_exc_or_excgroup, SystemExit):
raise maybe_exc_or_excgroup

# If we didn't raise above, there was an exception, but no
# SystemExit. So we raise here and except, so that the console
# can print the traceback to the user.
try:
raise maybe_exc_or_excgroup
except BaseException:
self.showtraceback()


async def run_repl(console: TrioInteractiveConsole) -> None:
banner = (
f"trio REPL {sys.version} on {sys.platform}\n"
f'Use "await" directly instead of "trio.run()".\n'
f'Type "help", "copyright", "credits" or "license" '
f"for more information.\n"
f'{getattr(sys, "ps1", ">>> ")}import trio'
)
try:
await trio.to_thread.run_sync(console.interact, banner)
finally:
warnings.filterwarnings(
"ignore",
message=r"^coroutine .* was never awaited$",
category=RuntimeWarning,
)


def main(original_locals: dict[str, object]) -> None:
with contextlib.suppress(ImportError):
import readline # noqa: F401
CoolCat467 marked this conversation as resolved.
Show resolved Hide resolved

repl_locals: dict[str, object] = {"trio": trio}
for key in {
"__name__",
"__package__",
"__loader__",
"__spec__",
"__builtins__",
"__file__",
}:
repl_locals[key] = original_locals[key]
A5rocks marked this conversation as resolved.
Show resolved Hide resolved

console = TrioInteractiveConsole(repl_locals)
trio.run(run_repl, console)
209 changes: 209 additions & 0 deletions src/trio/_tests/test_repl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from __future__ import annotations

import subprocess
import sys
from typing import Protocol

import pytest

import trio._repl


class RawInput(Protocol):
def __call__(self, prompt: str = "") -> str: ...


def build_raw_input(cmds: list[str]) -> RawInput:
"""
Pass in a list of strings.
Returns a callable that returns each string, each time its called
When there are not more strings to return, raise EOFError
"""
cmds_iter = iter(cmds)
prompts = []

def _raw_helper(prompt: str = "") -> str:
prompts.append(prompt)
try:
return next(cmds_iter)
except StopIteration:
raise EOFError from None

return _raw_helper


def test_build_raw_input() -> None:
"""Quick test of our helper function."""
raw_input = build_raw_input(["cmd1"])
assert raw_input() == "cmd1"
with pytest.raises(EOFError):
raw_input()


# In 3.10 or later, types.FunctionType (used internally) will automatically
# attach __builtins__ to the function objects. However we need to explicitly
# include it for 3.8 & 3.9
def build_locals() -> dict[str, object]:
return {"__builtins__": __builtins__}


async def test_basic_interaction(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Run some basic commands through the interpreter while capturing stdout.
Ensure that the interpreted prints the expected results.
"""
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# evaluate simple expression and recall the value
"x = 1",
"print(f'{x=}')",
# Literal gets printed
"'hello'",
# define and call sync function
"def func():",
" print(x + 1)",
"",
"func()",
# define and call async function
"async def afunc():",
" return 4",
"",
"await afunc()",
# import works
"import sys",
"sys.stdout.write('hello stdout\\n')",
]
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"]


async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"raise SystemExit",
]
)
monkeypatch.setattr(console, "raw_input", raw_input)
with pytest.raises(SystemExit):
await trio._repl.run_repl(console)


async def test_system_exits_in_exc_group(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"import sys",
"if sys.version_info < (3, 11):",
" from exceptiongroup import BaseExceptionGroup",
"",
"raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])",
"print('AFTER BaseExceptionGroup')",
]
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
# assert that raise SystemExit in an exception group
# doesn't quit
assert "AFTER BaseExceptionGroup" in out


async def test_system_exits_in_nested_exc_group(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"import sys",
"if sys.version_info < (3, 11):",
" from exceptiongroup import BaseExceptionGroup",
"",
"raise BaseExceptionGroup(",
" '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])",
"print('AFTER BaseExceptionGroup')",
]
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
# assert that raise SystemExit in an exception group
# doesn't quit
assert "AFTER BaseExceptionGroup" in out


async def test_base_exception_captured(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# The statement after raise should still get executed
"raise BaseException",
"print('AFTER BaseException')",
]
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "AFTER BaseException" in out


async def test_exc_group_captured(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# The statement after raise should still get executed
"raise ExceptionGroup('', [KeyError()])",
"print('AFTER ExceptionGroup')",
]
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "AFTER ExceptionGroup" in out


async def test_base_exception_capture_from_coroutine(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"async def async_func_raises_base_exception():",
" raise BaseException",
"",
# This will raise, but the statement after should still
# be executed
"await async_func_raises_base_exception()",
"print('AFTER BaseException')",
]
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "AFTER BaseException" in out


def test_main_entrypoint() -> None:
"""
Basic smoke test when running via the package __main__ entrypoint.
"""
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
assert repl.returncode == 0
Loading