Skip to content

Commit

Permalink
Fix fifo communication for large testing projects (microsoft#24690)
Browse files Browse the repository at this point in the history
revert the revert of the old commit so now main uses fifo again
add a limit of 4096 bytes per communication sent between python
subprocess and extension
fixes microsoft#24656
  • Loading branch information
eleanorjboyd committed Jan 7, 2025
1 parent ca25769 commit a671781
Show file tree
Hide file tree
Showing 20 changed files with 478 additions and 373 deletions.
1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def install_python_libs(session: nox.Session):
)

session.install("packaging")
session.install("debugpy")

# Download get-pip script
session.run(
Expand Down
56 changes: 17 additions & 39 deletions python_files/testing_tools/socket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,24 @@ def __exit__(self, *_):
self.close()

def connect(self):
if sys.platform == "win32":
self._writer = open(self.name, "w", encoding="utf-8") # noqa: SIM115, PTH123
# reader created in read method
else:
self._socket = _SOCKET(socket.AF_UNIX, socket.SOCK_STREAM)
self._socket.connect(self.name)
self._writer = open(self.name, "w", encoding="utf-8") # noqa: SIM115, PTH123
# reader created in read method
return self

def close(self):
if sys.platform == "win32":
self._writer.close()
else:
# add exception catch
self._socket.close()
self._writer.close()
if hasattr(self, "_reader"):
self._reader.close()

def write(self, data: str):
if sys.platform == "win32":
try:
# for windows, is should only use \n\n
request = (
f"""content-length: {len(data)}\ncontent-type: application/json\n\n{data}"""
)
self._writer.write(request)
self._writer.flush()
except Exception as e:
print("error attempting to write to pipe", e)
raise (e)
else:
# must include the carriage-return defined (as \r\n) for unix systems
request = (
f"""content-length: {len(data)}\r\ncontent-type: application/json\r\n\r\n{data}"""
)
self._socket.send(request.encode("utf-8"))
try:
# for windows, is should only use \n\n
request = f"""content-length: {len(data)}\ncontent-type: application/json\n\n{data}"""
self._writer.write(request)
self._writer.flush()
except Exception as e:
print("error attempting to write to pipe", e)
raise (e)

def read(self, bufsize=1024) -> str:
"""Read data from the socket.
Expand All @@ -63,17 +48,10 @@ def read(self, bufsize=1024) -> str:
Returns:
data (str): Data received from the socket.
"""
if sys.platform == "win32":
# returns a string automatically from read
if not hasattr(self, "_reader"):
self._reader = open(self.name, encoding="utf-8") # noqa: SIM115, PTH123
return self._reader.read(bufsize)
else:
# receive bytes and convert to string
while True:
part: bytes = self._socket.recv(bufsize)
data: str = part.decode("utf-8")
return data
# returns a string automatically from read
if not hasattr(self, "_reader"):
self._reader = open(self.name, encoding="utf-8") # noqa: SIM115, PTH123
return self._reader.read(bufsize)


class SocketManager:
Expand Down
33 changes: 27 additions & 6 deletions python_files/tests/pytestadapter/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,22 @@ def parse_rpc_message(data: str) -> Tuple[Dict[str, str], str]:
print("json decode error")


def _listen_on_fifo(pipe_name: str, result: List[str], completed: threading.Event):
# Open the FIFO for reading
fifo_path = pathlib.Path(pipe_name)
with fifo_path.open() as fifo:
print("Waiting for data...")
while True:
if completed.is_set():
break # Exit loop if completed event is set
data = fifo.read() # This will block until data is available
if len(data) == 0:
# If data is empty, assume EOF
break
print(f"Received: {data}")
result.append(data)


def _listen_on_pipe_new(listener, result: List[str], completed: threading.Event):
"""Listen on the named pipe or Unix domain socket for JSON data from the server.
Expand Down Expand Up @@ -307,14 +323,19 @@ def runner_with_cwd_env(
# if additional environment variables are passed, add them to the environment
if env_add:
env.update(env_add)
server = UnixPipeServer(pipe_name)
server.start()
# server = UnixPipeServer(pipe_name)
# server.start()
#################
# Create the FIFO (named pipe) if it doesn't exist
# if not pathlib.Path.exists(pipe_name):
os.mkfifo(pipe_name)
#################

completed = threading.Event()

result = [] # result is a string array to store the data during threading
t1: threading.Thread = threading.Thread(
target=_listen_on_pipe_new, args=(server, result, completed)
target=_listen_on_fifo, args=(pipe_name, result, completed)
)
t1.start()

Expand Down Expand Up @@ -364,14 +385,14 @@ def generate_random_pipe_name(prefix=""):

# For Windows, named pipes have a specific naming convention.
if sys.platform == "win32":
return f"\\\\.\\pipe\\{prefix}-{random_suffix}-sock"
return f"\\\\.\\pipe\\{prefix}-{random_suffix}"

# For Unix-like systems, use either the XDG_RUNTIME_DIR or a temporary directory.
xdg_runtime_dir = os.getenv("XDG_RUNTIME_DIR")
if xdg_runtime_dir:
return os.path.join(xdg_runtime_dir, f"{prefix}-{random_suffix}.sock") # noqa: PTH118
return os.path.join(xdg_runtime_dir, f"{prefix}-{random_suffix}") # noqa: PTH118
else:
return os.path.join(tempfile.gettempdir(), f"{prefix}-{random_suffix}.sock") # noqa: PTH118
return os.path.join(tempfile.gettempdir(), f"{prefix}-{random_suffix}") # noqa: PTH118


class UnixPipeServer:
Expand Down
19 changes: 13 additions & 6 deletions python_files/unittestadapter/pvsc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from typing_extensions import NotRequired # noqa: E402

from testing_tools import socket_manager # noqa: E402

# Types


Expand Down Expand Up @@ -331,10 +329,10 @@ def send_post_request(

if __writer is None:
try:
__writer = socket_manager.PipeManager(test_run_pipe)
__writer.connect()
__writer = open(test_run_pipe, "wb") # noqa: SIM115, PTH123
except Exception as error:
error_msg = f"Error attempting to connect to extension named pipe {test_run_pipe}[vscode-unittest]: {error}"
print(error_msg, file=sys.stderr)
__writer = None
raise VSCodeUnittestError(error_msg) from error

Expand All @@ -343,10 +341,19 @@ def send_post_request(
"params": payload,
}
data = json.dumps(rpc)

try:
if __writer:
__writer.write(data)
request = (
f"""content-length: {len(data)}\r\ncontent-type: application/json\r\n\r\n{data}"""
)
size = 4096
encoded = request.encode("utf-8")
bytes_written = 0
while bytes_written < len(encoded):
print("writing more bytes!")
segment = encoded[bytes_written : bytes_written + size]
bytes_written += __writer.write(segment)
__writer.flush()
else:
print(
f"Connection error[vscode-unittest], writer is None \n[vscode-unittest] data: \n{data} \n",
Expand Down
50 changes: 27 additions & 23 deletions python_files/vscode_pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@

import pytest

script_dir = pathlib.Path(__file__).parent.parent
sys.path.append(os.fspath(script_dir))
sys.path.append(os.fspath(script_dir / "lib" / "python"))
from testing_tools import socket_manager # noqa: E402

if TYPE_CHECKING:
from pluggy import Result

Expand Down Expand Up @@ -171,7 +166,7 @@ def pytest_exception_interact(node, call, report):
collected_test = TestRunResultDict()
collected_test[node_id] = item_result
cwd = pathlib.Path.cwd()
execution_post(
send_execution_message(
os.fsdecode(cwd),
"success",
collected_test if collected_test else None,
Expand Down Expand Up @@ -295,7 +290,7 @@ def pytest_report_teststatus(report, config): # noqa: ARG001
)
collected_test = TestRunResultDict()
collected_test[absolute_node_id] = item_result
execution_post(
send_execution_message(
os.fsdecode(cwd),
"success",
collected_test if collected_test else None,
Expand Down Expand Up @@ -329,7 +324,7 @@ def pytest_runtest_protocol(item, nextitem): # noqa: ARG001
)
collected_test = TestRunResultDict()
collected_test[absolute_node_id] = item_result
execution_post(
send_execution_message(
os.fsdecode(cwd),
"success",
collected_test if collected_test else None,
Expand Down Expand Up @@ -405,15 +400,15 @@ def pytest_sessionfinish(session, exitstatus):
"children": [],
"id_": "",
}
post_response(os.fsdecode(cwd), error_node)
send_discovery_message(os.fsdecode(cwd), error_node)
try:
session_node: TestNode | None = build_test_tree(session)
if not session_node:
raise VSCodePytestError(
"Something went wrong following pytest finish, \
no session node was created"
)
post_response(os.fsdecode(cwd), session_node)
send_discovery_message(os.fsdecode(cwd), session_node)
except Exception as e:
ERRORS.append(
f"Error Occurred, traceback: {(traceback.format_exc() if e.__traceback__ else '')}"
Expand All @@ -425,7 +420,7 @@ def pytest_sessionfinish(session, exitstatus):
"children": [],
"id_": "",
}
post_response(os.fsdecode(cwd), error_node)
send_discovery_message(os.fsdecode(cwd), error_node)
else:
if exitstatus == 0 or exitstatus == 1:
exitstatus_bool = "success"
Expand All @@ -435,7 +430,7 @@ def pytest_sessionfinish(session, exitstatus):
)
exitstatus_bool = "error"

execution_post(
send_execution_message(
os.fsdecode(cwd),
exitstatus_bool,
None,
Expand Down Expand Up @@ -485,7 +480,7 @@ def pytest_sessionfinish(session, exitstatus):
result=file_coverage_map,
error=None,
)
send_post_request(payload)
send_message(payload)


def build_test_tree(session: pytest.Session) -> TestNode:
Expand Down Expand Up @@ -853,8 +848,10 @@ def get_node_path(node: Any) -> pathlib.Path:
atexit.register(lambda: __writer.close() if __writer else None)


def execution_post(cwd: str, status: Literal["success", "error"], tests: TestRunResultDict | None):
"""Sends a POST request with execution payload details.
def send_execution_message(
cwd: str, status: Literal["success", "error"], tests: TestRunResultDict | None
):
"""Sends message execution payload details.
Args:
cwd (str): Current working directory.
Expand All @@ -866,10 +863,10 @@ def execution_post(cwd: str, status: Literal["success", "error"], tests: TestRun
)
if ERRORS:
payload["error"] = ERRORS
send_post_request(payload)
send_message(payload)


def post_response(cwd: str, session_node: TestNode) -> None:
def send_discovery_message(cwd: str, session_node: TestNode) -> None:
"""
Sends a POST request with test session details in payload.
Expand All @@ -885,7 +882,7 @@ def post_response(cwd: str, session_node: TestNode) -> None:
}
if ERRORS is not None:
payload["error"] = ERRORS
send_post_request(payload, cls_encoder=PathEncoder)
send_message(payload, cls_encoder=PathEncoder)


class PathEncoder(json.JSONEncoder):
Expand All @@ -897,7 +894,7 @@ def default(self, o):
return super().default(o)


def send_post_request(
def send_message(
payload: ExecutionPayloadDict | DiscoveryPayloadDict | CoveragePayloadDict,
cls_encoder=None,
):
Expand All @@ -922,8 +919,7 @@ def send_post_request(

if __writer is None:
try:
__writer = socket_manager.PipeManager(TEST_RUN_PIPE)
__writer.connect()
__writer = open(TEST_RUN_PIPE, "wb") # noqa: SIM115, PTH123
except Exception as error:
error_msg = f"Error attempting to connect to extension named pipe {TEST_RUN_PIPE}[vscode-pytest]: {error}"
print(error_msg, file=sys.stderr)
Expand All @@ -941,10 +937,18 @@ def send_post_request(
"params": payload,
}
data = json.dumps(rpc, cls=cls_encoder)

try:
if __writer:
__writer.write(data)
request = (
f"""content-length: {len(data)}\r\ncontent-type: application/json\r\n\r\n{data}"""
)
size = 4096
encoded = request.encode("utf-8")
bytes_written = 0
while bytes_written < len(encoded):
segment = encoded[bytes_written : bytes_written + size]
bytes_written += __writer.write(segment)
__writer.flush()
else:
print(
f"Plugin error connection error[vscode-pytest], writer is None \n[vscode-pytest] data: \n{data} \n",
Expand Down
2 changes: 2 additions & 0 deletions python_files/vscode_pytest/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# def send_post_request():
# return
Loading

0 comments on commit a671781

Please sign in to comment.