Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream dmypy output instead of dumping everything at the end #16252

Merged
merged 4 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions mypy/dmypy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Callable, Mapping, NoReturn

from mypy.dmypy_os import alive, kill
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive, send
from mypy.ipc import IPCClient, IPCException
from mypy.util import check_python_version, get_terminal_width, should_force_color
from mypy.version import __version__
Expand Down Expand Up @@ -659,28 +659,29 @@ def request(
# so that it can format the type checking output accordingly.
args["is_tty"] = sys.stdout.isatty() or should_force_color()
args["terminal_width"] = get_terminal_width()
bdata = json.dumps(args).encode("utf8")
_, name = get_status(status_file)
try:
with IPCClient(name, timeout) as client:
client.write(bdata)
response = receive(client)
send(client, args)

final = False
while not final:
response = receive(client)
final = bool(response.pop("final", False))
# Display debugging output written to stdout/stderr in the server process for convenience.
# This should not be confused with "out" and "err" fields in the response.
# Those fields hold the output of the "check" command, and are handled in check_output().
stdout = response.pop("stdout", None)
if stdout:
sys.stdout.write(stdout)
stderr = response.pop("stderr", None)
if stderr:
sys.stderr.write(stderr)
except (OSError, IPCException) as err:
return {"error": str(err)}
# TODO: Other errors, e.g. ValueError, UnicodeError
else:
# Display debugging output written to stdout/stderr in the server process for convenience.
# This should not be confused with "out" and "err" fields in the response.
# Those fields hold the output of the "check" command, and are handled in check_output().
stdout = response.get("stdout")
if stdout:
sys.stdout.write(stdout)
stderr = response.get("stderr")
if stderr:
print("-" * 79)
print("stderr:")
sys.stdout.write(stderr)
return response

return response


def get_status(status_file: str) -> tuple[int, str]:
Expand Down
20 changes: 9 additions & 11 deletions mypy/dmypy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import mypy.build
import mypy.errors
import mypy.main
from mypy.dmypy_util import receive
from mypy.dmypy_util import WriteToConn, receive, send
from mypy.find_sources import InvalidSourceList, create_source_list
from mypy.fscache import FileSystemCache
from mypy.fswatcher import FileData, FileSystemWatcher
Expand Down Expand Up @@ -208,21 +208,21 @@ def _response_metadata(self) -> dict[str, str]:

def serve(self) -> None:
"""Serve requests, synchronously (no thread or fork)."""

command = None
server = IPCServer(CONNECTION_NAME, self.timeout)
orig_stdout = sys.stdout
orig_stderr = sys.stderr

try:
with open(self.status_file, "w") as f:
json.dump({"pid": os.getpid(), "connection_name": server.connection_name}, f)
f.write("\n") # I like my JSON with a trailing newline
while True:
with server:
data = receive(server)
debug_stdout = io.StringIO()
debug_stderr = io.StringIO()
sys.stdout = debug_stdout
sys.stderr = debug_stderr
sys.stdout = WriteToConn(server, "stdout") # type: ignore[assignment]
sys.stderr = WriteToConn(server, "stderr") # type: ignore[assignment]
resp: dict[str, Any] = {}
if "command" not in data:
resp = {"error": "No command found in request"}
Expand All @@ -239,15 +239,13 @@ def serve(self) -> None:
tb = traceback.format_exception(*sys.exc_info())
resp = {"error": "Daemon crashed!\n" + "".join(tb)}
resp.update(self._response_metadata())
resp["stdout"] = debug_stdout.getvalue()
resp["stderr"] = debug_stderr.getvalue()
server.write(json.dumps(resp).encode("utf8"))
resp["final"] = True
send(server, resp)
raise
resp["stdout"] = debug_stdout.getvalue()
resp["stderr"] = debug_stderr.getvalue()
resp["final"] = True
try:
resp.update(self._response_metadata())
server.write(json.dumps(resp).encode("utf8"))
send(server, resp)
except OSError:
pass # Maybe the client hung up
if command == "stop":
Expand Down
33 changes: 30 additions & 3 deletions mypy/dmypy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from __future__ import annotations

import json
from typing import Any, Final
from typing import Any, Final, Iterable

from mypy.ipc import IPCBase

DEFAULT_STATUS_FILE: Final = ".dmypy.json"


def receive(connection: IPCBase) -> Any:
"""Receive JSON data from a connection until EOF.
"""Receive single JSON data frame from a connection.

Raise OSError if the data received is not valid JSON or if it is
not a dict.
Expand All @@ -23,9 +23,36 @@ def receive(connection: IPCBase) -> Any:
if not bdata:
raise OSError("No data received")
try:
data = json.loads(bdata.decode("utf8"))
data = json.loads(bdata)
except Exception as e:
raise OSError("Data received is not valid JSON") from e
if not isinstance(data, dict):
raise OSError(f"Data received is not a dict ({type(data)})")
return data


def send(connection: IPCBase, data: Any) -> None:
"""Send data to a connection encoded and framed.

The data must be JSON-serializable. We assume that a single send call is a
single frame to be sent on the connect.
"""
connection.write(json.dumps(data))


class WriteToConn:
"""Helper class to write to a connection instead of standard output."""

def __init__(self, server: IPCBase, output_key: str = "stdout"):
self.server = server
self.output_key = output_key

def write(self, output: str) -> int:
resp: dict[str, Any] = {}
resp[self.output_key] = output
send(self.server, resp)
return len(output)

def writelines(self, lines: Iterable[str]) -> None:
for s in lines:
self.write(s)
66 changes: 54 additions & 12 deletions mypy/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import base64
import codecs
import os
import shutil
import sys
Expand Down Expand Up @@ -40,19 +41,41 @@ class IPCBase:

This contains logic shared between the client and server, such as reading
and writing.
We want to be able to send multiple "messages" over a single connection and
to be able to separate the messages. We do this by encoding the messages
in an alphabet that does not contain spaces, then adding a space for
separation. The last framed message is also followed by a space.
"""

connection: _IPCHandle

def __init__(self, name: str, timeout: float | None) -> None:
self.name = name
self.timeout = timeout
self.buffer = bytearray()

def read(self, size: int = 100000) -> bytes:
"""Read bytes from an IPC connection until its empty."""
bdata = bytearray()
def frame_from_buffer(self) -> bytearray | None:
"""Return a full frame from the bytes we have in the buffer."""
space_pos = self.buffer.find(b" ")
if space_pos == -1:
return None
# We have a full frame
bdata = self.buffer[:space_pos]
self.buffer = self.buffer[space_pos + 1 :]
return bdata

def read(self, size: int = 100000) -> str:
"""Read bytes from an IPC connection until we have a full frame."""
bdata: bytearray | None = bytearray()
if sys.platform == "win32":
while True:
# Check if we already have a message in the buffer before
# receiving any more data from the socket.
bdata = self.frame_from_buffer()
if bdata is not None:
break

# Receive more data into the buffer.
ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
Expand All @@ -66,7 +89,10 @@ def read(self, size: int = 100000) -> bytes:
_, err = ov.GetOverlappedResult(True)
more = ov.getbuffer()
if more:
bdata.extend(more)
self.buffer.extend(more)
bdata = self.frame_from_buffer()
if bdata is not None:
break
if err == 0:
# we are done!
break
Expand All @@ -77,17 +103,34 @@ def read(self, size: int = 100000) -> bytes:
raise IPCException("ReadFile operation aborted.")
else:
while True:
# Check if we already have a message in the buffer before
# receiving any more data from the socket.
bdata = self.frame_from_buffer()
if bdata is not None:
break

# Receive more data into the buffer.
more = self.connection.recv(size)
if not more:
# Connection closed
break
bdata.extend(more)
return bytes(bdata)
self.buffer.extend(more)

if not bdata:
# Socket was empty and we didn't get any frame.
# This should only happen if the socket was closed.
return ""
return codecs.decode(bdata, "base64").decode("utf8")

def write(self, data: str) -> None:
"""Write to an IPC connection."""

# Frame the data by urlencoding it and separating by space.
encoded_data = codecs.encode(data.encode("utf8"), "base64") + b" "

def write(self, data: bytes) -> None:
"""Write bytes to an IPC connection."""
if sys.platform == "win32":
try:
ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
ov, err = _winapi.WriteFile(self.connection, encoded_data, overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
Expand All @@ -101,12 +144,11 @@ def write(self, data: bytes) -> None:
raise
bytes_written, err = ov.GetOverlappedResult(True)
assert err == 0, err
assert bytes_written == len(data)
assert bytes_written == len(encoded_data)
except OSError as e:
raise IPCException(f"Failed to write with error: {e.winerror}") from e
else:
self.connection.sendall(data)
self.connection.shutdown(socket.SHUT_WR)
self.connection.sendall(encoded_data)

def close(self) -> None:
if sys.platform == "win32":
Expand Down
52 changes: 44 additions & 8 deletions mypy/test/testipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,25 @@
def server(msg: str, q: Queue[str]) -> None:
server = IPCServer(CONNECTION_NAME)
q.put(server.connection_name)
data = b""
data = ""
while not data:
with server:
server.write(msg.encode())
server.write(msg)
data = server.read()
server.cleanup()


def server_multi_message_echo(q: Queue[str]) -> None:
server = IPCServer(CONNECTION_NAME)
q.put(server.connection_name)
data = ""
with server:
while data != "quit":
data = server.read()
server.write(data)
server.cleanup()


class IPCTests(TestCase):
def test_transaction_large(self) -> None:
queue: Queue[str] = Queue()
Expand All @@ -31,8 +42,8 @@ def test_transaction_large(self) -> None:
p.start()
connection_name = queue.get()
with IPCClient(connection_name, timeout=1) as client:
assert client.read() == msg.encode()
client.write(b"test")
assert client.read() == msg
client.write("test")
queue.close()
queue.join_thread()
p.join()
Expand All @@ -44,12 +55,37 @@ def test_connect_twice(self) -> None:
p.start()
connection_name = queue.get()
with IPCClient(connection_name, timeout=1) as client:
assert client.read() == msg.encode()
client.write(b"") # don't let the server hang up yet, we want to connect again.
assert client.read() == msg
client.write("") # don't let the server hang up yet, we want to connect again.

with IPCClient(connection_name, timeout=1) as client:
assert client.read() == msg.encode()
client.write(b"test")
assert client.read() == msg
client.write("test")
queue.close()
queue.join_thread()
p.join()
assert p.exitcode == 0

def test_multiple_messages(self) -> None:
queue: Queue[str] = Queue()
p = Process(target=server_multi_message_echo, args=(queue,), daemon=True)
p.start()
connection_name = queue.get()
with IPCClient(connection_name, timeout=1) as client:
# "foo bar" with extra accents on letters.
# In UTF-8 encoding so we don't confuse editors opening this file.
fancy_text = b"f\xcc\xb6o\xcc\xb2\xf0\x9d\x91\x9c \xd0\xb2\xe2\xb7\xa1a\xcc\xb6r\xcc\x93\xcd\x98\xcd\x8c"
client.write(fancy_text.decode("utf-8"))
assert client.read() == fancy_text.decode("utf-8")

client.write("Test with spaces")
client.write("Test write before reading previous")
time.sleep(0) # yield to the server to force reading of all messages by server.
assert client.read() == "Test with spaces"
assert client.read() == "Test write before reading previous"

client.write("quit")
assert client.read() == "quit"
queue.close()
queue.join_thread()
p.join()
Expand Down