diff --git a/src/python/pants/engine/console.py b/src/python/pants/engine/console.py index 3c7acb09e2c..0aa95b586d6 100644 --- a/src/python/pants/engine/console.py +++ b/src/python/pants/engine/console.py @@ -1,8 +1,9 @@ # Copyright 2018 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +from __future__ import annotations import sys -from typing import Callable, Optional +from typing import Callable, Optional, TextIO from colors import blue, cyan, green, magenta, red, yellow @@ -18,35 +19,45 @@ class Console: def __init__( self, - stdout=None, - stderr=None, + stdin: TextIO | None = None, + stdout: TextIO | None = None, + stderr: TextIO | None = None, use_colors: bool = True, session: Optional[SchedulerSession] = None, ): - """`stdout` and `stderr` may be explicitly provided when Console is constructed. - - We use this in tests to provide a mock we can write tests against, rather than writing to - the system stdout/stderr. If a SchedulerSession is set, any running UI will be torn down - before stdio is rendered. - """ + """If a SchedulerSession is set, any running UI will be torn down before stdio is + rendered.""" + self._stdin = stdin or sys.stdin self._stdout = stdout or sys.stdout self._stderr = stderr or sys.stderr self._use_colors = use_colors self._session = session @property - def stdout(self): + def stdin(self) -> TextIO: + if self._session: + self._session.teardown_dynamic_ui() + return self._stdin + + @property + def stdout(self) -> TextIO: if self._session: self._session.teardown_dynamic_ui() return self._stdout @property - def stderr(self): + def stderr(self) -> TextIO: if self._session: self._session.teardown_dynamic_ui() return self._stderr + def input(self, prompt: str | None = None) -> str: + """Equivalent to the `input` builtin, but clears any running UI before rendering.""" + if prompt is not None: + self.write_stdout(prompt) + return self.stdin.readline().rstrip("\n") + def write_stdout(self, payload: str) -> None: self.stdout.write(payload) diff --git a/src/python/pants/engine/internals/native.py b/src/python/pants/engine/internals/native.py index ba621199531..1cf60768e3b 100644 --- a/src/python/pants/engine/internals/native.py +++ b/src/python/pants/engine/internals/native.py @@ -3,7 +3,6 @@ from __future__ import annotations -import os from typing import Dict, Iterable, List, Optional, Tuple, cast from typing_extensions import Protocol diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index 601437c1a49..96cd3229ab6 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -1,6 +1,7 @@ # Copyright 2021 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +from io import RawIOBase from typing import Any, Dict, List, TextIO, Tuple # TODO: black and flake8 disagree about the content of this file: @@ -19,7 +20,7 @@ def stdio_initialize( log_levels_by_target: Dict[str, int], message_regex_filters: Tuple[str, ...], log_file: str, -) -> Tuple[TextIO, TextIO, TextIO]: ... +) -> Tuple[RawIOBase, TextIO, TextIO]: ... def stdio_thread_get_destination() -> PyStdioDestination: ... def stdio_thread_set_destination(destination: PyStdioDestination) -> None: ... def stdio_thread_console_set(stdin_fileno: int, stdout_fileno: int, stderr_fileno: int) -> None: ... diff --git a/src/python/pants/init/logging.py b/src/python/pants/init/logging.py index 14d595a7577..7178ec50761 100644 --- a/src/python/pants/init/logging.py +++ b/src/python/pants/init/logging.py @@ -2,10 +2,12 @@ # Licensed under the Apache License, Version 2.0 (see LICENSE). import http.client +import locale import logging import os import sys from contextlib import contextmanager +from io import BufferedReader, TextIOWrapper from logging import Formatter, LogRecord, StreamHandler from typing import Dict, Iterator @@ -154,7 +156,7 @@ def initialize_stdio(global_bootstrap_options: OptionValueContainer) -> Iterator # Initialize thread-local stdio, and replace sys.std* with proxies. original_stdin, original_stdout, original_stderr = sys.stdin, sys.stdout, sys.stderr try: - sys.stdin, sys.stdout, sys.stderr = native_engine.stdio_initialize( + raw_stdin, sys.stdout, sys.stderr = native_engine.stdio_initialize( global_level.level, log_show_rust_3rdparty, use_color, @@ -163,6 +165,14 @@ def initialize_stdio(global_bootstrap_options: OptionValueContainer) -> Iterator tuple(message_regex_filters), log_path, ) + sys.stdin = TextIOWrapper( + BufferedReader(raw_stdin), + # NB: We set the default encoding explicitly to bypass logic in the TextIOWrapper + # constructor that would poke the underlying file (which is not valid until a + # `stdio_destination` is set). + encoding=locale.getpreferredencoding(False), + ) + sys.__stdin__, sys.__stdout__, sys.__stderr__ = sys.stdin, sys.stdout, sys.stderr # Install a Python logger that will route through the Rust logger. with _python_logging_setup(global_level, print_stacktrace): diff --git a/src/python/pants/testutil/rule_runner.py b/src/python/pants/testutil/rule_runner.py index 2a1f87a96d3..4d6ff9fef77 100644 --- a/src/python/pants/testutil/rule_runner.py +++ b/src/python/pants/testutil/rule_runner.py @@ -462,13 +462,24 @@ def get(product, subject): @contextmanager def mock_console( options_bootstrapper: OptionsBootstrapper, + *, + stdin_content: bytes | str | None = None, ) -> Iterator[Tuple[Console, StdioReader]]: global_bootstrap_options = options_bootstrapper.bootstrap_options.for_global_scope() - with initialize_stdio(global_bootstrap_options), open( - "/dev/null", "r" - ) as stdin, temporary_file(binary_mode=False) as stdout, temporary_file( + + @contextmanager + def stdin_context(): + if stdin_content is None: + yield open("/dev/null", "r") + else: + with temporary_file(binary_mode=isinstance(stdin_content, bytes)) as stdin_file: + stdin_file.write(stdin_content) + stdin_file.close() + yield open(stdin_file.name, "r") + + with initialize_stdio(global_bootstrap_options), stdin_context() as stdin, temporary_file( binary_mode=False - ) as stderr, stdio_destination( + ) as stdout, temporary_file(binary_mode=False) as stderr, stdio_destination( stdin_fileno=stdin.fileno(), stdout_fileno=stdout.fileno(), stderr_fileno=stderr.fileno(), diff --git a/src/rust/engine/src/externs/stdio.rs b/src/rust/engine/src/externs/stdio.rs index 3f92b627eed..5853c511365 100644 --- a/src/rust/engine/src/externs/stdio.rs +++ b/src/rust/engine/src/externs/stdio.rs @@ -35,6 +35,7 @@ clippy::zero_ptr )] +use cpython::buffer::PyBuffer; use cpython::{exc, py_class, PyErr, PyObject, PyResult, Python}; /// @@ -64,6 +65,31 @@ py_class!(pub class PyStdioRead |py| { def fileno(&self) -> PyResult { stdio::get_destination().stdin_as_raw_fd().map_err(|e| PyErr::new::(py, (e,))) } + + def readinto(&self, obj: PyObject) -> PyResult { + let py_buffer = PyBuffer::get(py, &obj)?; + let mut buffer = vec![0; py_buffer.len_bytes() as usize]; + let read = py.allow_threads(|| { + stdio::get_destination().read_stdin(&mut buffer) + }).map_err(|e| PyErr::new::(py, (e.to_string(),)))?; + // NB: `as_mut_slice` exposes a `&[Cell]`, which we can't use directly in `read`. We use + // `copy_from_slice` instead, which unfortunately involves some extra copying. + py_buffer.copy_from_slice(py, &buffer)?; + Ok(read) + } + + @property + def closed(&self) -> PyResult { + Ok(false) + } + + def readable(&self) -> PyResult { + Ok(true) + } + + def seekable(&self) -> PyResult { + Ok(false) + } }); /// diff --git a/tests/python/pants_test/init/test_logging.py b/tests/python/pants_test/init/test_logging.py index e0e92dcea5e..dbbfbca3396 100644 --- a/tests/python/pants_test/init/test_logging.py +++ b/tests/python/pants_test/init/test_logging.py @@ -7,6 +7,7 @@ from pants.engine.internals import native_engine from pants.init.logging import initialize_stdio from pants.testutil.option_util import create_options_bootstrapper +from pants.testutil.rule_runner import mock_console from pants.util.contextutil import temporary_dir from pants.util.logging import LogLevel @@ -62,3 +63,16 @@ def test_log_filtering_by_rule() -> None: assert "[INFO] log msg one" in loglines[0] assert "[DEBUG] log msg three" in loglines[1] assert len(loglines) == 2 + + +def test_stdin_input() -> None: + ob = create_options_bootstrapper([]) + expected_input = "my_input" + expected_output = "my_output" + with mock_console(ob, stdin_content=f"{expected_input}\n") as (_, stdio_reader): + assert expected_input == input(expected_output) + assert expected_output == stdio_reader.get_stdout() + + with mock_console(ob, stdin_content=f"{expected_input}\n") as (console, stdio_reader): + assert expected_input == console.input(expected_output) + assert expected_output == stdio_reader.get_stdout()