diff --git a/Lib/pdb.py b/Lib/pdb.py index 0de8bbe37e471e..3a21579b5bbe11 100644 --- a/Lib/pdb.py +++ b/Lib/pdb.py @@ -77,6 +77,7 @@ import json import token import types +import atexit import codeop import pprint import signal @@ -92,11 +93,12 @@ import itertools import traceback import linecache +import selectors +import threading import _colorize import _pyrepl.utils -from contextlib import closing -from contextlib import contextmanager +from contextlib import ExitStack, closing, contextmanager from rlcompleter import Completer from types import CodeType from warnings import deprecated @@ -2670,12 +2672,21 @@ async def set_trace_async(*, header=None, commands=None): # Remote PDB class _PdbServer(Pdb): - def __init__(self, sockfile, owns_sockfile=True, **kwargs): + def __init__( + self, + sockfile, + signal_server=None, + owns_sockfile=True, + **kwargs, + ): self._owns_sockfile = owns_sockfile self._interact_state = None self._sockfile = sockfile self._command_name_cache = [] self._write_failed = False + if signal_server: + # Only started by the top level _PdbServer, not recursive ones. + self._start_signal_listener(signal_server) super().__init__(colorize=False, **kwargs) @staticmethod @@ -2731,15 +2742,49 @@ def _ensure_valid_message(self, msg): f"PDB message doesn't follow the schema! {msg}" ) + @classmethod + def _start_signal_listener(cls, address): + def listener(sock): + with closing(sock): + # Check if the interpreter is finalizing every quarter of a second. + # Clean up and exit if so. + sock.settimeout(0.25) + sock.shutdown(socket.SHUT_WR) + while not shut_down.is_set(): + try: + data = sock.recv(1024) + except socket.timeout: + continue + if data == b"": + return # EOF + signal.raise_signal(signal.SIGINT) + + def stop_thread(): + shut_down.set() + thread.join() + + # Use a daemon thread so that we don't detach until after all non-daemon + # threads are done. 
Use an atexit handler to stop gracefully at that point, + # so that our thread is stopped before the interpreter is torn down. + shut_down = threading.Event() + thread = threading.Thread( + target=listener, + args=[socket.create_connection(address, timeout=5)], + daemon=True, + ) + atexit.register(stop_thread) + thread.start() + def _send(self, **kwargs): self._ensure_valid_message(kwargs) json_payload = json.dumps(kwargs) try: self._sockfile.write(json_payload.encode() + b"\n") self._sockfile.flush() - except OSError: - # This means that the client has abruptly disconnected, but we'll - # handle that the next time we try to read from the client instead + except (OSError, ValueError): + # We get an OSError if the network connection has dropped, and a + # ValueError if the sockfile has been closed (e.g. by detach()). We'll + # handle this the next time we try to read from the client instead # of trying to handle it from everywhere _send() may be called. # Track this with a flag rather than assuming readline() will ever # return an empty string because the socket may be half-closed. 
@@ -2967,10 +3012,15 @@ def default(self, line): class _PdbClient: - def __init__(self, pid, sockfile, interrupt_script): + def __init__(self, pid, server_socket, interrupt_sock): self.pid = pid - self.sockfile = sockfile - self.interrupt_script = interrupt_script + self.read_buf = b"" + self.signal_read = None + self.signal_write = None + self.sigint_received = False + self.raise_on_sigint = False + self.server_socket = server_socket + self.interrupt_sock = interrupt_sock self.pdb_instance = Pdb() self.pdb_commands = set() self.completion_matches = [] @@ -3012,8 +3062,7 @@ def _send(self, **kwargs): self._ensure_valid_message(kwargs) json_payload = json.dumps(kwargs) try: - self.sockfile.write(json_payload.encode() + b"\n") - self.sockfile.flush() + self.server_socket.sendall(json_payload.encode() + b"\n") except OSError: # This means that the client has abruptly disconnected, but we'll # handle that the next time we try to read from the client instead @@ -3022,10 +3071,44 @@ def _send(self, **kwargs): # return an empty string because the socket may be half-closed. self.write_failed = True - def read_command(self, prompt): - self.multiline_block = False - reply = input(prompt) + def _readline(self): + if self.sigint_received: + # There's a pending unhandled SIGINT. Handle it now. + self.sigint_received = False + raise KeyboardInterrupt + + # Wait for either a SIGINT or a line or EOF from the PDB server. + selector = selectors.DefaultSelector() + selector.register(self.signal_read, selectors.EVENT_READ) + selector.register(self.server_socket, selectors.EVENT_READ) + + while b"\n" not in self.read_buf: + for key, _ in selector.select(): + if key.fileobj == self.signal_read: + self.signal_read.recv(1024) + if self.sigint_received: + # If not, we're reading wakeup events for sigints that + # we've previously handled, and can ignore them. 
+ self.sigint_received = False + raise KeyboardInterrupt + elif key.fileobj == self.server_socket: + data = self.server_socket.recv(16 * 1024) + self.read_buf += data + if not data and b"\n" not in self.read_buf: + # EOF without a full final line. Drop the partial line. + self.read_buf = b"" + return b"" + + ret, sep, self.read_buf = self.read_buf.partition(b"\n") + return ret + sep + + def read_input(self, prompt, multiline_block): + self.multiline_block = multiline_block + with self._sigint_raises_keyboard_interrupt(): + return input(prompt) + def read_command(self, prompt): + reply = self.read_input(prompt, multiline_block=False) if self.state == "dumb": # No logic applied whatsoever, just pass the raw reply back. return reply @@ -3048,10 +3131,9 @@ def read_command(self, prompt): return prefix + reply # Otherwise, valid first line of a multi-line statement - self.multiline_block = True - continue_prompt = "...".ljust(len(prompt)) + more_prompt = "...".ljust(len(prompt)) while codeop.compile_command(reply, "", "single") is None: - reply += "\n" + input(continue_prompt) + reply += "\n" + self.read_input(more_prompt, multiline_block=True) return prefix + reply @@ -3076,11 +3158,70 @@ def readline_completion(self, completer): finally: readline.set_completer(old_completer) + @contextmanager + def _sigint_handler(self): + # Signal handling strategy: + # - When we call input() we want a SIGINT to raise KeyboardInterrupt + # - Otherwise we want to write to the wakeup FD and set a flag. + # We'll break out of select() when the wakeup FD is written to, + # and we'll check the flag whenever we're about to accept input. + def handler(signum, frame): + self.sigint_received = True + if self.raise_on_sigint: + # One-shot; don't raise again until the flag is set again. 
+ self.raise_on_sigint = False + self.sigint_received = False + raise KeyboardInterrupt + + sentinel = object() + old_handler = sentinel + old_wakeup_fd = sentinel + + self.signal_read, self.signal_write = socket.socketpair() + with (closing(self.signal_read), closing(self.signal_write)): + self.signal_read.setblocking(False) + self.signal_write.setblocking(False) + + try: + old_handler = signal.signal(signal.SIGINT, handler) + + try: + old_wakeup_fd = signal.set_wakeup_fd( + self.signal_write.fileno(), + warn_on_full_buffer=False, + ) + yield + finally: + # Restore the old wakeup fd if we installed a new one + if old_wakeup_fd is not sentinel: + signal.set_wakeup_fd(old_wakeup_fd) + finally: + self.signal_read = self.signal_write = None + if old_handler is not sentinel: + # Restore the old handler if we installed a new one + signal.signal(signal.SIGINT, old_handler) + + @contextmanager + def _sigint_raises_keyboard_interrupt(self): + if self.sigint_received: + # There's a pending unhandled SIGINT. Handle it now. + self.sigint_received = False + raise KeyboardInterrupt + + try: + self.raise_on_sigint = True + yield + finally: + self.raise_on_sigint = False + def cmdloop(self): - with self.readline_completion(self.complete): + with ( + self._sigint_handler(), + self.readline_completion(self.complete), + ): while not self.write_failed: try: - if not (payload_bytes := self.sockfile.readline()): + if not (payload_bytes := self._readline()): break except KeyboardInterrupt: self.send_interrupt() @@ -3098,11 +3239,17 @@ def cmdloop(self): self.process_payload(payload) def send_interrupt(self): - print( - "\n*** Program will stop at the next bytecode instruction." - " (Use 'cont' to resume)." - ) - sys.remote_exec(self.pid, self.interrupt_script) + if self.interrupt_sock is not None: + # Write to a socket that the PDB server listens on. This triggers + # the remote to raise a SIGINT for itself. We do this because + # Windows doesn't allow triggering SIGINT remotely. 
+ # See https://stackoverflow.com/a/35792192 for many more details. + self.interrupt_sock.sendall(signal.SIGINT.to_bytes()) + else: + # On Unix we can just send a SIGINT to the remote process. + # This is preferable to using the signal thread approach that we + # use on Windows because it can interrupt IO in the main thread. + os.kill(self.pid, signal.SIGINT) def process_payload(self, payload): match payload: @@ -3172,7 +3319,7 @@ def complete(self, text, state): if self.write_failed: return None - payload = self.sockfile.readline() + payload = self._readline() if not payload: return None @@ -3189,11 +3336,18 @@ def complete(self, text, state): return None -def _connect(host, port, frame, commands, version): +def _connect(*, host, port, frame, commands, version, signal_raising_thread): with closing(socket.create_connection((host, port))) as conn: sockfile = conn.makefile("rwb") - remote_pdb = _PdbServer(sockfile) + # The client requests this thread on Windows but not on Unix. + # Most tests don't request this thread, to keep them simpler. 
+ if signal_raising_thread: + signal_server = (host, port) + else: + signal_server = None + + remote_pdb = _PdbServer(sockfile, signal_server=signal_server) weakref.finalize(remote_pdb, sockfile.close) if Pdb._last_pdb_instance is not None: @@ -3214,43 +3368,48 @@ def _connect(host, port, frame, commands, version): def attach(pid, commands=()): """Attach to a running process with the given PID.""" - with closing(socket.create_server(("localhost", 0))) as server: + with ExitStack() as stack: + server = stack.enter_context( + closing(socket.create_server(("localhost", 0))) + ) port = server.getsockname()[1] - with tempfile.NamedTemporaryFile("w", delete_on_close=False) as connect_script: - connect_script.write( - textwrap.dedent( - f""" - import pdb, sys - pdb._connect( - host="localhost", - port={port}, - frame=sys._getframe(1), - commands={json.dumps("\n".join(commands))}, - version={_PdbServer.protocol_version()}, - ) - """ + connect_script = stack.enter_context( + tempfile.NamedTemporaryFile("w", delete_on_close=False) + ) + + use_signal_thread = sys.platform == "win32" + + connect_script.write( + textwrap.dedent( + f""" + import pdb, sys + pdb._connect( + host="localhost", + port={port}, + frame=sys._getframe(1), + commands={json.dumps("\n".join(commands))}, + version={_PdbServer.protocol_version()}, + signal_raising_thread={use_signal_thread!r}, ) + """ ) - connect_script.close() - sys.remote_exec(pid, connect_script.name) - - # TODO Add a timeout? Or don't bother since the user can ^C? - client_sock, _ = server.accept() + ) + connect_script.close() + sys.remote_exec(pid, connect_script.name) - with closing(client_sock): - sockfile = client_sock.makefile("rwb") + # TODO Add a timeout? Or don't bother since the user can ^C? 
+ client_sock, _ = server.accept() + stack.enter_context(closing(client_sock)) - with closing(sockfile): - with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script: - interrupt_script.write( - 'import pdb, sys\n' - 'if inst := pdb.Pdb._last_pdb_instance:\n' - ' inst.set_trace(sys._getframe(1))\n' - ) - interrupt_script.close() + if use_signal_thread: + interrupt_sock, _ = server.accept() + stack.enter_context(closing(interrupt_sock)) + interrupt_sock.setblocking(False) + else: + interrupt_sock = None - _PdbClient(pid, sockfile, interrupt_script.name).cmdloop() + _PdbClient(pid, client_sock, interrupt_sock).cmdloop() # Post-Mortem interface diff --git a/Lib/test/test_remote_pdb.py b/Lib/test/test_remote_pdb.py index 9fbe94fcdd6da7..9c794991dd5ed9 100644 --- a/Lib/test/test_remote_pdb.py +++ b/Lib/test/test_remote_pdb.py @@ -12,7 +12,7 @@ import threading import unittest import unittest.mock -from contextlib import contextmanager, redirect_stdout, ExitStack +from contextlib import closing, contextmanager, redirect_stdout, ExitStack from pathlib import Path from test.support import is_wasi, os_helper, requires_subprocess, SHORT_TIMEOUT from test.support.os_helper import temp_dir, TESTFN, unlink @@ -79,44 +79,6 @@ def get_output(self) -> List[dict]: return results -class MockDebuggerSocket: - """Mock file-like simulating a connection to a _RemotePdb instance""" - - def __init__(self, incoming): - self.incoming = iter(incoming) - self.outgoing = [] - self.buffered = bytearray() - - def write(self, data: bytes) -> None: - """Simulate write to socket.""" - self.buffered += data - - def flush(self) -> None: - """Ensure each line is valid JSON.""" - lines = self.buffered.splitlines(keepends=True) - self.buffered.clear() - for line in lines: - assert line.endswith(b"\n") - self.outgoing.append(json.loads(line)) - - def readline(self) -> bytes: - """Read a line from the prepared input queue.""" - # Anything written must be flushed before trying to 
read, - # since the read will be dependent upon the last write. - assert not self.buffered - try: - item = next(self.incoming) - if not isinstance(item, bytes): - item = json.dumps(item).encode() - return item + b"\n" - except StopIteration: - return b"" - - def close(self) -> None: - """No-op close implementation.""" - pass - - class PdbClientTestCase(unittest.TestCase): """Tests for the _PdbClient class.""" @@ -124,8 +86,11 @@ def do_test( self, *, incoming, - simulate_failure=None, + simulate_send_failure=False, + simulate_sigint_during_stdout_write=False, + use_interrupt_socket=False, expected_outgoing=None, + expected_outgoing_signals=None, expected_completions=None, expected_exception=None, expected_stdout="", @@ -134,6 +99,8 @@ def do_test( ): if expected_outgoing is None: expected_outgoing = [] + if expected_outgoing_signals is None: + expected_outgoing_signals = [] if expected_completions is None: expected_completions = [] if expected_state is None: @@ -142,16 +109,6 @@ def do_test( expected_state.setdefault("write_failed", False) messages = [m for source, m in incoming if source == "server"] prompts = [m["prompt"] for source, m in incoming if source == "user"] - sockfile = MockDebuggerSocket(messages) - stdout = io.StringIO() - - if simulate_failure: - sockfile.write = unittest.mock.Mock() - sockfile.flush = unittest.mock.Mock() - if simulate_failure == "write": - sockfile.write.side_effect = OSError("write failed") - elif simulate_failure == "flush": - sockfile.flush.side_effect = OSError("flush failed") input_iter = (m for source, m in incoming if source == "user") completions = [] @@ -178,18 +135,60 @@ def mock_input(prompt): reply = message["input"] if isinstance(reply, BaseException): raise reply - return reply + if isinstance(reply, str): + return reply + return reply() with ExitStack() as stack: + client_sock, server_sock = socket.socketpair() + stack.enter_context(closing(client_sock)) + stack.enter_context(closing(server_sock)) + + server_sock = 
unittest.mock.Mock(wraps=server_sock) + + client_sock.sendall( + b"".join( + (m if isinstance(m, bytes) else json.dumps(m).encode()) + b"\n" + for m in messages + ) + ) + client_sock.shutdown(socket.SHUT_WR) + + if simulate_send_failure: + server_sock.sendall = unittest.mock.Mock( + side_effect=OSError("sendall failed") + ) + client_sock.shutdown(socket.SHUT_RD) + + stdout = io.StringIO() + + if simulate_sigint_during_stdout_write: + orig_stdout_write = stdout.write + + def sigint_stdout_write(s): + signal.raise_signal(signal.SIGINT) + return orig_stdout_write(s) + + stdout.write = sigint_stdout_write + input_mock = stack.enter_context( unittest.mock.patch("pdb.input", side_effect=mock_input) ) stack.enter_context(redirect_stdout(stdout)) + if use_interrupt_socket: + interrupt_sock = unittest.mock.Mock(spec=socket.socket) + mock_kill = None + else: + interrupt_sock = None + mock_kill = stack.enter_context( + unittest.mock.patch("os.kill", spec=os.kill) + ) + client = _PdbClient( - pid=0, - sockfile=sockfile, - interrupt_script="/a/b.py", + pid=12345, + server_socket=server_sock, + interrupt_sock=interrupt_sock, ) if expected_exception is not None: @@ -199,13 +198,12 @@ def mock_input(prompt): client.cmdloop() - actual_outgoing = sockfile.outgoing - if simulate_failure: - actual_outgoing += [ - json.loads(msg.args[0]) for msg in sockfile.write.mock_calls - ] + sent_msgs = [msg.args[0] for msg in server_sock.sendall.mock_calls] + for msg in sent_msgs: + assert msg.endswith(b"\n") + actual_outgoing = [json.loads(msg) for msg in sent_msgs] - self.assertEqual(sockfile.outgoing, expected_outgoing) + self.assertEqual(actual_outgoing, expected_outgoing) self.assertEqual(completions, expected_completions) if expected_stdout_substring and not expected_stdout: self.assertIn(expected_stdout_substring, stdout.getvalue()) @@ -215,6 +213,20 @@ def mock_input(prompt): actual_state = {k: getattr(client, k) for k in expected_state} self.assertEqual(actual_state, expected_state) + if 
use_interrupt_socket: + outgoing_signals = [ + signal.Signals(int.from_bytes(call.args[0])) + for call in interrupt_sock.sendall.call_args_list + ] + else: + assert mock_kill is not None + outgoing_signals = [] + for call in mock_kill.call_args_list: + pid, signum = call.args + self.assertEqual(pid, 12345) + outgoing_signals.append(signal.Signals(signum)) + self.assertEqual(outgoing_signals, expected_outgoing_signals) + def test_remote_immediately_closing_the_connection(self): """Test the behavior when the remote closes the connection immediately.""" incoming = [] @@ -409,11 +421,38 @@ def test_handling_unrecognized_prompt_type(self): expected_state={"state": "dumb"}, ) - def test_keyboard_interrupt_at_prompt(self): - """Test signaling when a prompt gets a KeyboardInterrupt.""" + def test_sigint_at_prompt(self): + """Test signaling when a prompt gets interrupted.""" incoming = [ ("server", {"prompt": "(Pdb) ", "state": "pdb"}), - ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}), + ( + "user", + { + "prompt": "(Pdb) ", + "input": lambda: signal.raise_signal(signal.SIGINT), + }, + ), + ] + self.do_test( + incoming=incoming, + expected_outgoing=[ + {"signal": "INT"}, + ], + expected_state={"state": "pdb"}, + ) + + def test_sigint_at_continuation_prompt(self): + """Test signaling when a continuation prompt gets interrupted.""" + incoming = [ + ("server", {"prompt": "(Pdb) ", "state": "pdb"}), + ("user", {"prompt": "(Pdb) ", "input": "if True:"}), + ( + "user", + { + "prompt": "... 
", + "input": lambda: signal.raise_signal(signal.SIGINT), + }, + ), ] self.do_test( incoming=incoming, @@ -423,6 +462,22 @@ def test_keyboard_interrupt_at_prompt(self): expected_state={"state": "pdb"}, ) + def test_sigint_when_writing(self): + """Test signaling when sys.stdout.write() gets interrupted.""" + incoming = [ + ("server", {"message": "Some message or other\n", "type": "info"}), + ] + for use_interrupt_socket in [False, True]: + with self.subTest(use_interrupt_socket=use_interrupt_socket): + self.do_test( + incoming=incoming, + simulate_sigint_during_stdout_write=True, + use_interrupt_socket=use_interrupt_socket, + expected_outgoing=[], + expected_outgoing_signals=[signal.SIGINT], + expected_stdout="Some message or other\n", + ) + def test_eof_at_prompt(self): """Test signaling when a prompt gets an EOFError.""" incoming = [ @@ -478,20 +533,7 @@ def test_write_failing(self): self.do_test( incoming=incoming, expected_outgoing=[{"signal": "INT"}], - simulate_failure="write", - expected_state={"write_failed": True}, - ) - - def test_flush_failing(self): - """Test terminating if flush fails due to a half closed socket.""" - incoming = [ - ("server", {"prompt": "(Pdb) ", "state": "pdb"}), - ("user", {"prompt": "(Pdb) ", "input": KeyboardInterrupt()}), - ] - self.do_test( - incoming=incoming, - expected_outgoing=[{"signal": "INT"}], - simulate_failure="flush", + simulate_send_failure=True, expected_state={"write_failed": True}, ) @@ -660,42 +702,7 @@ def test_write_failure_during_completion(self): }, {"reply": "xyz"}, ], - simulate_failure="write", - expected_completions=[], - expected_state={"state": "interact", "write_failed": True}, - ) - - def test_flush_failure_during_completion(self): - """Test failing to flush to the socket to request tab completions.""" - incoming = [ - ("server", {"prompt": ">>> ", "state": "interact"}), - ( - "user", - { - "prompt": ">>> ", - "completion_request": { - "line": "xy", - "begidx": 0, - "endidx": 2, - }, - "input": "xyz", 
- }, - ), - ] - self.do_test( - incoming=incoming, - expected_outgoing=[ - { - "complete": { - "text": "xy", - "line": "xy", - "begidx": 0, - "endidx": 2, - } - }, - {"reply": "xyz"}, - ], - simulate_failure="flush", + simulate_send_failure=True, expected_completions=[], expected_state={"state": "interact", "write_failed": True}, ) @@ -1032,6 +1039,7 @@ def dummy_function(): frame=frame, commands="", version=pdb._PdbServer.protocol_version(), + signal_raising_thread=False, ) return x # This line won't be reached in debugging @@ -1089,23 +1097,6 @@ def _send_command(self, client_file, command): client_file.write(json.dumps({"reply": command}).encode() + b"\n") client_file.flush() - def _send_interrupt(self, pid): - """Helper to send an interrupt signal to the debugger.""" - # with tempfile.NamedTemporaryFile("w", delete_on_close=False) as interrupt_script: - interrupt_script = TESTFN + "_interrupt_script.py" - with open(interrupt_script, 'w') as f: - f.write( - 'import pdb, sys\n' - 'print("Hello, world!")\n' - 'if inst := pdb.Pdb._last_pdb_instance:\n' - ' inst.set_trace(sys._getframe(1))\n' - ) - self.addCleanup(unlink, interrupt_script) - try: - sys.remote_exec(pid, interrupt_script) - except PermissionError: - self.skipTest("Insufficient permissions to execute code in remote process") - def test_connect_and_basic_commands(self): """Test connecting to a remote debugger and sending basic commands.""" self._create_script() @@ -1218,6 +1209,7 @@ def bar(): frame=frame, commands="", version=pdb._PdbServer.protocol_version(), + signal_raising_thread=True, ) print("Connected to debugger") iterations = 50 @@ -1233,6 +1225,10 @@ def bar(): self._create_script(script=script) process, client_file = self._connect_and_get_client_file() + # Accept a 2nd connection from the subprocess to tell it about signals + signal_sock, _ = self.server_sock.accept() + self.addCleanup(signal_sock.close) + with kill_on_error(process): # Skip initial messages until we get to the prompt 
self._read_until_prompt(client_file) @@ -1248,7 +1244,7 @@ def bar(): break # Inject a script to interrupt the running process - self._send_interrupt(process.pid) + signal_sock.sendall(signal.SIGINT.to_bytes()) messages = self._read_until_prompt(client_file) # Verify we got the keyboard interrupt message. @@ -1304,6 +1300,7 @@ def run_test(): frame=frame, commands="", version=fake_version, + signal_raising_thread=False, ) # This should print if the debugger detaches correctly diff --git a/Misc/NEWS.d/next/Library/2025-05-01-18-32-44.gh-issue-133223.KE_T5f.rst b/Misc/NEWS.d/next/Library/2025-05-01-18-32-44.gh-issue-133223.KE_T5f.rst new file mode 100644 index 00000000000000..41d2f87a79056b --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-05-01-18-32-44.gh-issue-133223.KE_T5f.rst @@ -0,0 +1,2 @@ +When PDB is attached to a remote process, do a better job of intercepting +Ctrl+C and forwarding it to the remote process.