Skip to content

Commit

Permalink
Better tests + fix bug with multiple frames in buffer + base64 encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
svalentin committed Oct 13, 2023
1 parent c160551 commit 73ecb8a
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 41 deletions.
19 changes: 2 additions & 17 deletions mypy/dmypy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import time
import traceback
from contextlib import redirect_stderr, redirect_stdout
from typing import AbstractSet, Any, Callable, Final, Iterable, List, Sequence, Tuple
from typing import AbstractSet, Any, Callable, Final, List, Sequence, Tuple
from typing_extensions import TypeAlias as _TypeAlias

import mypy.build
import mypy.errors
import mypy.main
from mypy.dmypy_util import receive, send
from mypy.dmypy_util import receive, send, WriteToConn
from mypy.find_sources import InvalidSourceList, create_source_list
from mypy.fscache import FileSystemCache
from mypy.fswatcher import FileData, FileSystemWatcher
Expand Down Expand Up @@ -209,21 +209,6 @@ def _response_metadata(self) -> dict[str, str]:
def serve(self) -> None:
"""Serve requests, synchronously (no thread or fork)."""

class WriteToConn:
    """File-like helper that forwards written text over an IPC connection.

    Each ``write`` call is sent as its own message: a dict that maps
    ``output_key`` to the text that was written.
    """

    def __init__(self, server: IPCBase, output_key: str = "stdout") -> None:
        """Store the connection to write to and the key to file output under."""
        self.server = server
        self.output_key = output_key

    def write(self, output: str) -> int:
        """Send ``output`` as one message and return its length in characters."""
        resp: dict[str, Any] = {self.output_key: output}
        # Use the connection this object was built with, rather than relying
        # on a variable that happens to exist in the enclosing scope.
        send(self.server, resp)
        return len(output)

    def writelines(self, lines: Iterable[str]) -> None:
        """Send every string in ``lines`` as a separate message."""
        for s in lines:
            self.write(s)

command = None
server = IPCServer(CONNECTION_NAME, self.timeout)

Expand Down
23 changes: 20 additions & 3 deletions mypy/dmypy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from __future__ import annotations

import json
from typing import Any, Final
from typing import Any, Final, Iterable
from io import TextIOWrapper

from mypy.ipc import IPCBase

Expand Down Expand Up @@ -35,7 +36,23 @@ def send(connection: IPCBase, data: Any) -> None:
The data must be JSON-serializable. We assume that a single send call is a
single frame to be sent on the connection.
As an easy way to separate frames, we base64-encode them and separate by space.
Last frame also has a trailing space.
"""
connection.write(json.dumps(data))


class WriteToConn:
    """File-like helper to write to a connection instead of standard output.

    Each ``write`` call is sent over ``server`` as its own message: a dict
    that maps ``output_key`` to the text that was written.
    """

    def __init__(self, server: IPCBase, output_key: str = "stdout") -> None:
        """Store the connection to write to and the key to file output under."""
        self.server = server
        self.output_key = output_key

    def write(self, output: str) -> int:
        """Send ``output`` as one message and return its length in characters."""
        resp: dict[str, Any] = {self.output_key: output}
        send(self.server, resp)
        return len(output)

    def writelines(self, lines: Iterable[str]) -> None:
        """Send every string in ``lines`` as a separate message."""
        for s in lines:
            self.write(s)

    def flush(self) -> None:
        """Do nothing; this object keeps no buffer of its own.

        Present so that code calling ``flush()`` on the stream it is given
        does not fail with ``AttributeError``.
        """

    def isatty(self) -> bool:
        """Return False: the output goes to a connection, not a terminal."""
        return False
58 changes: 41 additions & 17 deletions mypy/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import tempfile
from types import TracebackType
from typing import Callable, Final
from urllib.parse import quote,unquote
import codecs

if sys.platform == "win32":
# This may be private, but it is needed for IPC on Windows, and is basically stable
Expand Down Expand Up @@ -41,6 +41,10 @@ class IPCBase:
This contains logic shared between the client and server, such as reading
and writing.
We want to be able to send multiple "messages" over a single connection and
to be able to separate the messages. We do this by encoding the messages
in an alphabet that does not contain spaces, then adding a space for
separation. The last framed message is also followed by a space.
"""

connection: _IPCHandle
Expand All @@ -50,11 +54,28 @@ def __init__(self, name: str, timeout: float | None) -> None:
self.timeout = timeout
self.buffer = bytearray()

def frame_from_buffer(self) -> bytearray | None:
    """Pop the first complete frame off the front of the buffer.

    A frame is everything up to the first space byte. When one is found,
    the frame and its separator are removed from ``self.buffer`` and the
    frame is returned without the separator. When the buffer holds no space
    byte, the buffer is left untouched and None is returned.
    """
    frame, separator, remainder = self.buffer.partition(b" ")
    if not separator:
        # No separator yet, so the frame is still incomplete.
        return None
    self.buffer = remainder
    return frame

def read(self, size: int = 100000) -> str:
"""Read bytes from an IPC connection until its empty."""
bdata = bytearray()
"""Read bytes from an IPC connection until we have a full frame."""
bdata : bytearray | None = bytearray()
if sys.platform == "win32":
while True:
# Check if we already have a message in the buffer before
# receiving any more data from the socket.
bdata = self.frame_from_buffer()
if bdata is not None:
break

# Receive more data into the buffer.
ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
Expand All @@ -69,11 +90,8 @@ def read(self, size: int = 100000) -> str:
more = ov.getbuffer()
if more:
self.buffer.extend(more)
space_pos = self.buffer.find(b" ")
if space_pos != -1:
# We have a full frame
bdata = self.buffer[: space_pos]
self.buffer = self.buffer[space_pos + 1 :]
bdata = self.frame_from_buffer()
if bdata is not None:
break
if err == 0:
# we are done!
Expand All @@ -85,24 +103,30 @@ def read(self, size: int = 100000) -> str:
raise IPCException("ReadFile operation aborted.")
else:
while True:
# Check if we already have a message in the buffer before
# receiving any more data from the socket.
bdata = self.frame_from_buffer()
if bdata is not None:
break

# Receive more data into the buffer.
more = self.connection.recv(size)
if not more:
# Connection closed
break
self.buffer.extend(more)
space_pos = self.buffer.find(b" ")
if space_pos != -1:
# We have a full frame
bdata = self.buffer[: space_pos]
self.buffer = self.buffer[space_pos + 1 :]
break
return unquote(bytes(bdata).decode("utf8"))

if not bdata:
# Socket was empty and we didn't get any frame.
# This should only happen if the socket was closed.
return ''
return codecs.decode(bdata, 'base64').decode("utf8")

def write(self, data: str) -> None:
"""Write bytes to an IPC connection."""
"""Write to an IPC connection."""

# Frame the data by base64-encoding it and separating by space.
encoded_data = (quote(data) + " ").encode("utf8")
encoded_data = codecs.encode(data.encode("utf8"), 'base64') + b' '

if sys.platform == "win32":
try:
Expand Down
16 changes: 12 additions & 4 deletions mypy/test/testipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,18 @@ def test_multiple_messages(self) -> None:
p.start()
connection_name = queue.get()
with IPCClient(connection_name, timeout=1) as client:
client.write("test1")
assert client.read() == "test1"
client.write("test2")
assert client.read() == "test2"
# "foo bar" with extra accents on letters.
# In UTF-8 encoding so we don't confuse editors opening this file.
fancy_text = b'f\xcc\xb6o\xcc\xb2\xf0\x9d\x91\x9c \xd0\xb2\xe2\xb7\xa1a\xcc\xb6r\xcc\x93\xcd\x98\xcd\x8c'
client.write(fancy_text.decode('utf-8'))
assert client.read() == fancy_text.decode('utf-8')

client.write("Test with spaces")
client.write("Test write before reading previous")
time.sleep(0) # yield to the server to force reading of all messages by server.
assert client.read() == "Test with spaces"
assert client.read() == "Test write before reading previous"

client.write("quit")
assert client.read() == "quit"
queue.close()
Expand Down

0 comments on commit 73ecb8a

Please sign in to comment.