From fcc44e639d033101930116c6d97743a1fdd5b5f8 Mon Sep 17 00:00:00 2001 From: Pablo Galindo Date: Tue, 21 May 2024 11:51:35 -0400 Subject: [PATCH] gh-118893: Evaluate all statements in the new REPL separately --- Lib/_pyrepl/simple_interact.py | 26 ++++++++++++--- Lib/test/test_pyrepl.py | 61 +++++++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/Lib/_pyrepl/simple_interact.py b/Lib/_pyrepl/simple_interact.py index 31b2097a78a2268..4cc6fd2815c1e27 100644 --- a/Lib/_pyrepl/simple_interact.py +++ b/Lib/_pyrepl/simple_interact.py @@ -30,6 +30,7 @@ import linecache import sys import code +import ast from types import ModuleType from .readline import _get_reader, multiline_input @@ -77,6 +78,26 @@ def __init__( def showtraceback(self): super().showtraceback(colorize=self.can_colorize) + def runsource(self, source, filename="", symbol="single"): + tree = ast.parse(source) + if tree.body: + *_, last_stmt = tree.body + for stmt in tree.body: + wrapper = ast.Interactive if stmt is last_stmt else ast.Module + the_symbol = symbol if stmt is last_stmt else "exec" + item = wrapper([stmt]) + try: + code = compile(item, filename, the_symbol) + except (OverflowError, ValueError): + self.showsyntaxerror(filename) + return False + + if code is None: + return True + + self.runcode(code) + return False + def run_multiline_interactive_console( mainmodule: ModuleType | None= None, future_flags: int = 0 @@ -144,10 +165,7 @@ def more_lines(unicodetext: str) -> bool: input_name = f"" linecache._register_code(input_name, statement, "") # type: ignore[attr-defined] - symbol = "single" if not contains_pasted_code else "exec" - more = console.push(_strip_final_indent(statement), filename=input_name, _symbol=symbol) # type: ignore[call-arg] - if contains_pasted_code and more: - more = console.push(_strip_final_indent(statement), filename=input_name, _symbol="single") # type: ignore[call-arg] + more = console.push(_strip_final_indent(statement), filename=input_name, _symbol="single") # type: ignore[call-arg] assert not more input_n += 1 except KeyboardInterrupt: diff --git a/Lib/test/test_pyrepl.py b/Lib/test/test_pyrepl.py index c8990b699b214cb..96da24fc19033b2 100644 --- a/Lib/test/test_pyrepl.py +++ b/Lib/test/test_pyrepl.py @@ -1,13 +1,15 @@ import itertools import os import rlcompleter -import sys import tempfile import unittest from code import InteractiveConsole from functools import partial from unittest import TestCase from unittest.mock import MagicMock, patch +from textwrap import dedent +import contextlib +import io from test.support import requires from test.support.import_helper import import_module @@ -1002,5 +1004,62 @@ def test_up_arrow_after_ctrl_r(self): self.assert_screen_equals(reader, "") +class TestSimpleInteract(unittest.TestCase): + def test_multiple_statements(self): + namespace = {} + code = dedent("""\ + class A: + def foo(self): + + + pass + + class B: + def bar(self): + pass + + a = 1 + a + """) + console = InteractiveColoredConsole(namespace, filename="") + with ( + patch.object(InteractiveColoredConsole, "showsyntaxerror") as showsyntaxerror, + patch.object(InteractiveColoredConsole, "runsource", wraps=console.runsource) as runsource, + ): + more = console.push(code, filename="", _symbol="single") # type: ignore[call-arg] + self.assertFalse(more) + showsyntaxerror.assert_not_called() + + + def test_multiple_statements_output(self): + namespace = {} + code = dedent("""\ + b = 1 + b + a = 1 + a + """) + console = InteractiveColoredConsole(namespace, filename="") + f = io.StringIO() + with contextlib.redirect_stdout(f): + more = console.push(code, filename="", _symbol="single") # type: ignore[call-arg] + self.assertFalse(more) + self.assertEqual(f.getvalue(), "1\n") + + def test_empty(self): + namespace = {} + code = "" + console = InteractiveColoredConsole(namespace, filename="") + f = io.StringIO() + with contextlib.redirect_stdout(f): + more = console.push(code, filename="", _symbol="single") # type: ignore[call-arg] + self.assertFalse(more) + self.assertEqual(f.getvalue(), "") + + +if __name__ == '__main__': + unittest.main() + + if __name__ == '__main__': unittest.main()