Skip to content

Commit

Permalink
Add new Prompt class.
Browse files Browse the repository at this point in the history
  • Loading branch information
byllyfish committed Jul 14, 2023
1 parent f635bba commit e3fc863
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 45 deletions.
3 changes: 2 additions & 1 deletion .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
],
"ignoreRegExpList": [
"# pylint:.*",
"`[a-z]+`"
"`[a-z]+`",
"__[a-z]+__"
]
}
111 changes: 111 additions & 0 deletions shellous/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"Implements the Prompt utility class."

import asyncio
from typing import Optional

from shellous.runner import Runner
from shellous.harvest import harvest_results


class Prompt:
"""Utility class to help with an interactive prompt session.
This is an experimental API.
Example:
```
cmd = sh("sh").stdin(sh.CAPTURE).stdout(sh.CAPTURE).stderr(sh.STDOUT)
async with cmd.env(PS1="??? ") as run:
prompt = Prompt(run, "??? ")
result = await prompt.send("echo hello")
assert result == "hello\n"
prompt.close()
```
"""

runner: Runner
prompt_bytes: bytes
default_timeout: Optional[float]

def __init__(
self,
runner: Runner,
prompt: str,
*,
default_timeout: Optional[float] = None,
):
assert runner.stdin is not None
assert runner.stdout is not None

self.runner = runner
self.prompt_bytes = prompt.encode("utf-8")
self.default_timeout = default_timeout

async def send(
self,
input_text: str = "",
*,
timeout: float | None = None,
) -> str:
"Write some input text to stdin, then await the response."
stdin = self.runner.stdin
stdout = self.runner.stdout
assert stdin is not None
assert stdout is not None

if timeout is None:
timeout = self.default_timeout

if input_text:
stdin.write(input_text.encode("utf-8") + b"\n")

# Drain our write to stdin, and wait for prompt from stdout.
cancelled, (buf, _) = await harvest_results(
_read_until(stdout, self.prompt_bytes),
stdin.drain(),
timeout=timeout,
)
assert isinstance(buf, bytes)

if cancelled:
raise asyncio.CancelledError()

# Clean up the output to remove the prompt, then return as string.
buf = buf.replace(b"\r\n", b"\n")
if buf.endswith(self.prompt_bytes):
prompt_len = len(self.prompt_bytes)
buf = buf[0:-prompt_len].rstrip(b"\n")

return buf.decode("utf-8")

def close(self):
"Close stdin to end the prompt session."
assert self.runner.stdin is not None
self.runner.stdin.close()


async def _read_until(stream: asyncio.StreamReader, separator: bytes) -> bytes:
"Read all data until the separator."
try:
# Most reads can complete without buffering.
return await stream.readuntil(separator)
except asyncio.IncompleteReadError as ex:
return ex.partial
except asyncio.LimitOverrunError as ex:
# Okay, we have to buffer.
buf = bytearray(await stream.read(ex.consumed))

while True:
try:
buf.extend(await stream.readuntil(separator))
except asyncio.IncompleteReadError as ex:
buf.extend(ex.partial)
except asyncio.LimitOverrunError as ex:
buf.extend(await stream.read(ex.consumed))
continue
break

return bytes(buf)
58 changes: 14 additions & 44 deletions tests/test_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,11 @@
import pytest

from shellous import sh
from shellous.harvest import harvest_results
from shellous.log import LOGGER
from shellous.prompt import Prompt

_PS1 = "<+==+>"


class Prompt:
"Utility class to help with an interactive prompt."

def __init__(self, stdin, stdout):
self.stdin = stdin
self.stdout = stdout
self.prompt_bytes = _PS1.encode("utf-8")

async def prompt(self, input_text=""):
"Write some input text to stdin, then await the response."
if input_text:
self.stdin.write(input_text.encode("utf-8") + b"\n")

# Drain our write to stdin, and wait for prompt from stdout.
cancelled, (buf, _) = await harvest_results(
self.stdout.readuntil(self.prompt_bytes),
self.stdin.drain(),
timeout=5.0,
)
assert isinstance(buf, bytes)

if cancelled:
raise asyncio.CancelledError()

# Clean up the output to remove the prompt, then return as string.
buf = buf.replace(b"\r\n", b"\n")
assert buf.endswith(self.prompt_bytes)
promptlen = len(self.prompt_bytes)
buf = buf[0:-promptlen].rstrip(b"\n")

return buf.decode("utf-8")
# Custom Python REPL prompt.
_PROMPT = "<+=+=+>"


async def run_asyncio_repl(cmds, logfile=None):
Expand All @@ -59,31 +27,33 @@ async def run_asyncio_repl(cmds, logfile=None):
)

async with repl as run:
assert run.stdin is not None
prompt = Prompt(run, _PROMPT, default_timeout=5.0)

p = Prompt(run.stdin, run.stdout)
await p.prompt(f"import sys; sys.ps1 = '{_PS1}'; sys.ps2 = ''")
# Customize the python REPL prompt to make it easier to detect. The
# initial output of the REPL will include the old ">>> " prompt which
# we can ignore.
await prompt.send(f"import sys; sys.ps1 = '{_PROMPT}'; sys.ps2 = ''")

# Optionally redirect logging to a file.
await p.prompt("import shellous.log, logging")
await prompt.send("import shellous.log, logging")
if logfile:
await p.prompt("shellous.log.LOGGER.setLevel(logging.DEBUG)")
await p.prompt(
await prompt.send("shellous.log.LOGGER.setLevel(logging.DEBUG)")
await prompt.send(
f"logging.basicConfig(filename='{logfile}', level=logging.DEBUG)"
)
else:
# I don't want random logging messages to confuse the output.
await p.prompt("shellous.log.LOGGER.setLevel(logging.ERROR)")
await prompt.send("shellous.log.LOGGER.setLevel(logging.ERROR)")

output = []
for cmd in cmds:
LOGGER.info(" repl: %r", cmd)
output.append(await p.prompt(cmd))
output.append(await prompt.send(cmd))
# Give tasks a chance to get started.
if ".create_task(" in cmd:
await asyncio.sleep(0.1)

run.stdin.close()
prompt.close()

result = run.result()
assert result.exit_code == 0
Expand Down

0 comments on commit e3fc863

Please sign in to comment.