Skip to content

Commit

Permalink
Stream dmypy output instead of dumping everything at the end (#16252)
Browse files Browse the repository at this point in the history
This does 2 things:
1. It changes the IPC code to work with multiple messages.
2. It changes the dmypy client/server communication so that it streams
stdout/stderr instead of dumping everything at the end.

For 1, we have to provide a way to separate out different messages. I
chose to frame messages as bytes separated by whitespace character. That
means we have to encode the message in a scheme that escapes whitespace.
The `codecs.encode(<bytes_data>, 'base64')` seems reasonable. It encodes more
than needed but the application is not IPC IO limited so it should be fine.
With this convention in place, all we have to do is read from the socket
stream until we have a whitespace character.
The framing logic can be easily changed.

For 2, since we communicate with JSONs, it's easy to add a "finished"
key that tells us it's the final response from dmypy. Anything else is
just stdout/stderr output.

Note: dmypy server also returns out/err which is the output of actual
mypy type checking. Right now this change does not stream that output.
We can stream that in a followup change. We just have to decide on how
to differenciate the 4 text streams (stdout/stderr/out/err) that will
now be interleaved.

The WriteToConn class could use more love. I just put a bare minimum.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
svalentin and pre-commit-ci[bot] authored Oct 16, 2023
1 parent e435594 commit 2bcec24
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 51 deletions.
35 changes: 18 additions & 17 deletions mypy/dmypy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Callable, Mapping, NoReturn

from mypy.dmypy_os import alive, kill
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive, send
from mypy.ipc import IPCClient, IPCException
from mypy.util import check_python_version, get_terminal_width, should_force_color
from mypy.version import __version__
Expand Down Expand Up @@ -659,28 +659,29 @@ def request(
# so that it can format the type checking output accordingly.
args["is_tty"] = sys.stdout.isatty() or should_force_color()
args["terminal_width"] = get_terminal_width()
bdata = json.dumps(args).encode("utf8")
_, name = get_status(status_file)
try:
with IPCClient(name, timeout) as client:
client.write(bdata)
response = receive(client)
send(client, args)

final = False
while not final:
response = receive(client)
final = bool(response.pop("final", False))
# Display debugging output written to stdout/stderr in the server process for convenience.
# This should not be confused with "out" and "err" fields in the response.
# Those fields hold the output of the "check" command, and are handled in check_output().
stdout = response.pop("stdout", None)
if stdout:
sys.stdout.write(stdout)
stderr = response.pop("stderr", None)
if stderr:
sys.stderr.write(stderr)
except (OSError, IPCException) as err:
return {"error": str(err)}
# TODO: Other errors, e.g. ValueError, UnicodeError
else:
# Display debugging output written to stdout/stderr in the server process for convenience.
# This should not be confused with "out" and "err" fields in the response.
# Those fields hold the output of the "check" command, and are handled in check_output().
stdout = response.get("stdout")
if stdout:
sys.stdout.write(stdout)
stderr = response.get("stderr")
if stderr:
print("-" * 79)
print("stderr:")
sys.stdout.write(stderr)
return response

return response


def get_status(status_file: str) -> tuple[int, str]:
Expand Down
20 changes: 9 additions & 11 deletions mypy/dmypy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import mypy.build
import mypy.errors
import mypy.main
from mypy.dmypy_util import receive
from mypy.dmypy_util import WriteToConn, receive, send
from mypy.find_sources import InvalidSourceList, create_source_list
from mypy.fscache import FileSystemCache
from mypy.fswatcher import FileData, FileSystemWatcher
Expand Down Expand Up @@ -208,21 +208,21 @@ def _response_metadata(self) -> dict[str, str]:

def serve(self) -> None:
"""Serve requests, synchronously (no thread or fork)."""

command = None
server = IPCServer(CONNECTION_NAME, self.timeout)
orig_stdout = sys.stdout
orig_stderr = sys.stderr

try:
with open(self.status_file, "w") as f:
json.dump({"pid": os.getpid(), "connection_name": server.connection_name}, f)
f.write("\n") # I like my JSON with a trailing newline
while True:
with server:
data = receive(server)
debug_stdout = io.StringIO()
debug_stderr = io.StringIO()
sys.stdout = debug_stdout
sys.stderr = debug_stderr
sys.stdout = WriteToConn(server, "stdout") # type: ignore[assignment]
sys.stderr = WriteToConn(server, "stderr") # type: ignore[assignment]
resp: dict[str, Any] = {}
if "command" not in data:
resp = {"error": "No command found in request"}
Expand All @@ -239,15 +239,13 @@ def serve(self) -> None:
tb = traceback.format_exception(*sys.exc_info())
resp = {"error": "Daemon crashed!\n" + "".join(tb)}
resp.update(self._response_metadata())
resp["stdout"] = debug_stdout.getvalue()
resp["stderr"] = debug_stderr.getvalue()
server.write(json.dumps(resp).encode("utf8"))
resp["final"] = True
send(server, resp)
raise
resp["stdout"] = debug_stdout.getvalue()
resp["stderr"] = debug_stderr.getvalue()
resp["final"] = True
try:
resp.update(self._response_metadata())
server.write(json.dumps(resp).encode("utf8"))
send(server, resp)
except OSError:
pass # Maybe the client hung up
if command == "stop":
Expand Down
33 changes: 30 additions & 3 deletions mypy/dmypy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from __future__ import annotations

import json
from typing import Any, Final
from typing import Any, Final, Iterable

from mypy.ipc import IPCBase

DEFAULT_STATUS_FILE: Final = ".dmypy.json"


def receive(connection: IPCBase) -> Any:
"""Receive JSON data from a connection until EOF.
"""Receive single JSON data frame from a connection.
Raise OSError if the data received is not valid JSON or if it is
not a dict.
Expand All @@ -23,9 +23,36 @@ def receive(connection: IPCBase) -> Any:
if not bdata:
raise OSError("No data received")
try:
data = json.loads(bdata.decode("utf8"))
data = json.loads(bdata)
except Exception as e:
raise OSError("Data received is not valid JSON") from e
if not isinstance(data, dict):
raise OSError(f"Data received is not a dict ({type(data)})")
return data


def send(connection: IPCBase, data: Any) -> None:
"""Send data to a connection encoded and framed.
The data must be JSON-serializable. We assume that a single send call is a
single frame to be sent on the connect.
"""
connection.write(json.dumps(data))


class WriteToConn:
"""Helper class to write to a connection instead of standard output."""

def __init__(self, server: IPCBase, output_key: str = "stdout"):
self.server = server
self.output_key = output_key

def write(self, output: str) -> int:
resp: dict[str, Any] = {}
resp[self.output_key] = output
send(self.server, resp)
return len(output)

def writelines(self, lines: Iterable[str]) -> None:
for s in lines:
self.write(s)
66 changes: 54 additions & 12 deletions mypy/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import base64
import codecs
import os
import shutil
import sys
Expand Down Expand Up @@ -40,19 +41,41 @@ class IPCBase:
This contains logic shared between the client and server, such as reading
and writing.
We want to be able to send multiple "messages" over a single connection and
to be able to separate the messages. We do this by encoding the messages
in an alphabet that does not contain spaces, then adding a space for
separation. The last framed message is also followed by a space.
"""

connection: _IPCHandle

def __init__(self, name: str, timeout: float | None) -> None:
self.name = name
self.timeout = timeout
self.buffer = bytearray()

def read(self, size: int = 100000) -> bytes:
"""Read bytes from an IPC connection until its empty."""
bdata = bytearray()
def frame_from_buffer(self) -> bytearray | None:
"""Return a full frame from the bytes we have in the buffer."""
space_pos = self.buffer.find(b" ")
if space_pos == -1:
return None
# We have a full frame
bdata = self.buffer[:space_pos]
self.buffer = self.buffer[space_pos + 1 :]
return bdata

def read(self, size: int = 100000) -> str:
"""Read bytes from an IPC connection until we have a full frame."""
bdata: bytearray | None = bytearray()
if sys.platform == "win32":
while True:
# Check if we already have a message in the buffer before
# receiving any more data from the socket.
bdata = self.frame_from_buffer()
if bdata is not None:
break

# Receive more data into the buffer.
ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
Expand All @@ -66,7 +89,10 @@ def read(self, size: int = 100000) -> bytes:
_, err = ov.GetOverlappedResult(True)
more = ov.getbuffer()
if more:
bdata.extend(more)
self.buffer.extend(more)
bdata = self.frame_from_buffer()
if bdata is not None:
break
if err == 0:
# we are done!
break
Expand All @@ -77,17 +103,34 @@ def read(self, size: int = 100000) -> bytes:
raise IPCException("ReadFile operation aborted.")
else:
while True:
# Check if we already have a message in the buffer before
# receiving any more data from the socket.
bdata = self.frame_from_buffer()
if bdata is not None:
break

# Receive more data into the buffer.
more = self.connection.recv(size)
if not more:
# Connection closed
break
bdata.extend(more)
return bytes(bdata)
self.buffer.extend(more)

if not bdata:
# Socket was empty and we didn't get any frame.
# This should only happen if the socket was closed.
return ""
return codecs.decode(bdata, "base64").decode("utf8")

def write(self, data: str) -> None:
"""Write to an IPC connection."""

# Frame the data by urlencoding it and separating by space.
encoded_data = codecs.encode(data.encode("utf8"), "base64") + b" "

def write(self, data: bytes) -> None:
"""Write bytes to an IPC connection."""
if sys.platform == "win32":
try:
ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
ov, err = _winapi.WriteFile(self.connection, encoded_data, overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
Expand All @@ -101,12 +144,11 @@ def write(self, data: bytes) -> None:
raise
bytes_written, err = ov.GetOverlappedResult(True)
assert err == 0, err
assert bytes_written == len(data)
assert bytes_written == len(encoded_data)
except OSError as e:
raise IPCException(f"Failed to write with error: {e.winerror}") from e
else:
self.connection.sendall(data)
self.connection.shutdown(socket.SHUT_WR)
self.connection.sendall(encoded_data)

def close(self) -> None:
if sys.platform == "win32":
Expand Down
52 changes: 44 additions & 8 deletions mypy/test/testipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,25 @@
def server(msg: str, q: Queue[str]) -> None:
server = IPCServer(CONNECTION_NAME)
q.put(server.connection_name)
data = b""
data = ""
while not data:
with server:
server.write(msg.encode())
server.write(msg)
data = server.read()
server.cleanup()


def server_multi_message_echo(q: Queue[str]) -> None:
server = IPCServer(CONNECTION_NAME)
q.put(server.connection_name)
data = ""
with server:
while data != "quit":
data = server.read()
server.write(data)
server.cleanup()


class IPCTests(TestCase):
def test_transaction_large(self) -> None:
queue: Queue[str] = Queue()
Expand All @@ -31,8 +42,8 @@ def test_transaction_large(self) -> None:
p.start()
connection_name = queue.get()
with IPCClient(connection_name, timeout=1) as client:
assert client.read() == msg.encode()
client.write(b"test")
assert client.read() == msg
client.write("test")
queue.close()
queue.join_thread()
p.join()
Expand All @@ -44,12 +55,37 @@ def test_connect_twice(self) -> None:
p.start()
connection_name = queue.get()
with IPCClient(connection_name, timeout=1) as client:
assert client.read() == msg.encode()
client.write(b"") # don't let the server hang up yet, we want to connect again.
assert client.read() == msg
client.write("") # don't let the server hang up yet, we want to connect again.

with IPCClient(connection_name, timeout=1) as client:
assert client.read() == msg.encode()
client.write(b"test")
assert client.read() == msg
client.write("test")
queue.close()
queue.join_thread()
p.join()
assert p.exitcode == 0

def test_multiple_messages(self) -> None:
queue: Queue[str] = Queue()
p = Process(target=server_multi_message_echo, args=(queue,), daemon=True)
p.start()
connection_name = queue.get()
with IPCClient(connection_name, timeout=1) as client:
# "foo bar" with extra accents on letters.
# In UTF-8 encoding so we don't confuse editors opening this file.
fancy_text = b"f\xcc\xb6o\xcc\xb2\xf0\x9d\x91\x9c \xd0\xb2\xe2\xb7\xa1a\xcc\xb6r\xcc\x93\xcd\x98\xcd\x8c"
client.write(fancy_text.decode("utf-8"))
assert client.read() == fancy_text.decode("utf-8")

client.write("Test with spaces")
client.write("Test write before reading previous")
time.sleep(0) # yield to the server to force reading of all messages by server.
assert client.read() == "Test with spaces"
assert client.read() == "Test write before reading previous"

client.write("quit")
assert client.read() == "quit"
queue.close()
queue.join_thread()
p.join()
Expand Down

0 comments on commit 2bcec24

Please sign in to comment.