Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move JsonRPCServer.start_io to high-level asyncio API #506

Merged
merged 2 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 20 additions & 42 deletions pygls/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import asyncio
import json
import logging
import re
import sys
import typing
from threading import Event

from pygls.exceptions import PyglsError, JsonRpcException, JsonRpcInternalError
from pygls.exceptions import JsonRpcException, JsonRpcInternalError, PyglsError
from pygls.io_ import run_async
from pygls.protocol import JsonRPCProtocol, default_converter
from pygls.server import WebSocketTransportAdapter

Expand Down Expand Up @@ -104,7 +104,15 @@ async def start_io(self, cmd: str, *args, **kwargs):
raise RuntimeError("Server process is missing a stdout stream")

self.protocol.connection_made(server.stdin) # type: ignore
connection = asyncio.create_task(self.run_async(server.stdout))
connection = asyncio.create_task(
run_async(
stop_event=self._stop_event,
reader=server.stdout,
protocol=self.protocol,
logger=logger,
error_handler=self.report_server_error,
)
)
notify_exit = asyncio.create_task(self._server_exit())

self._server = server
Expand All @@ -115,7 +123,15 @@ async def start_tcp(self, host: str, port: int):
reader, writer = await asyncio.open_connection(host, port)

self.protocol.connection_made(writer) # type: ignore
connection = asyncio.create_task(self.run_async(reader))
connection = asyncio.create_task(
run_async(
stop_event=self._stop_event,
reader=reader,
protocol=self.protocol,
logger=logger,
error_handler=self.report_server_error,
)
)

self._async_tasks.extend([connection])

Expand All @@ -139,44 +155,6 @@ async def start_ws(self, host: str, port: int):
connection = asyncio.create_task(self.run_websocket(websocket))
self._async_tasks.extend([connection])

async def run_async(self, reader: asyncio.StreamReader):
"""Run the main message processing loop, asynchronously"""

CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$")
content_length = 0

while not self._stop_event.is_set():
# Read a header line
header = await reader.readline()
if not header:
break

# Extract content length if possible
if not content_length:
match = CONTENT_LENGTH_PATTERN.fullmatch(header)
if match:
content_length = int(match.group(1))
logger.debug("Content length: %s", content_length)

# Check if all headers have been read (as indicated by an empty line \r\n)
if content_length and not header.strip():
# Read body
body = await reader.readexactly(content_length)
if not body:
break

try:
message = json.loads(
body, object_hook=self.protocol.structure_message
)
self.protocol.handle_message(message)
except Exception as exc:
logger.exception("Unable to handle message")
self._report_server_error(exc, JsonRpcInternalError)

# Reset
content_length = 0

async def run_websocket(self, websocket: ClientConnection):
"""Run the main message processing loop, over websockets."""

Expand Down
193 changes: 193 additions & 0 deletions pygls/io_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
############################################################################
# Copyright(c) Open Law Library. All rights reserved. #
# See ThirdPartyNotices.txt in the project root for additional notices. #
# #
# Licensed under the Apache License, Version 2.0 (the "License") #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http: // www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
############################################################################
from __future__ import annotations

import asyncio
import json
import logging
import re
import typing

from pygls.exceptions import JsonRpcException

if typing.TYPE_CHECKING:
import logging
import threading
from collections.abc import Awaitable
from concurrent.futures import ThreadPoolExecutor
from typing import Any, BinaryIO, Callable, Protocol

from pygls.protocol import JsonRPCProtocol

class Reader(Protocol):
"""An synchronous reader."""

def readline(self) -> bytes: ...

def read(self, n: int) -> bytes: ...

class AsyncReader(typing.Protocol):
"""An asynchronous reader."""

def readline(self) -> Awaitable[bytes]: ...

def readexactly(self, n: int) -> Awaitable[bytes]: ...


class StdinAsyncReader:
"""Read from stdin asynchronously."""

def __init__(self, stdin: BinaryIO, executor: ThreadPoolExecutor | None = None):
self.stdin = stdin
self._loop: asyncio.AbstractEventLoop | None = None
self.executor = executor

@property
def loop(self):
if self._loop is None:
self._loop = asyncio.get_running_loop()

return self._loop

def readline(self) -> Awaitable[bytes]:
return self.loop.run_in_executor(self.executor, self.stdin.readline)

def readexactly(self, n: int) -> Awaitable[bytes]:
return self.loop.run_in_executor(self.executor, self.stdin.read, n)


async def run_async(
stop_event: threading.Event,
reader: AsyncReader,
protocol: JsonRPCProtocol,
logger: logging.Logger | None = None,
error_handler: Callable[[Exception, type[JsonRpcException]], Any] | None = None,
):
"""Run a main message processing loop, asynchronously

Parameters
----------
stop_event
A ``threading.Event`` used to break the main loop

reader
The reader to read messages from

protocol
The protocol instance that should handle the messages

logger
The logger instance to use
"""

CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$")
content_length = 0
logger = logger or logging.getLogger(__name__)

while not stop_event.is_set():
# Read a header line
header = await reader.readline()
if not header:
break

# Extract content length if possible
if not content_length:
match = CONTENT_LENGTH_PATTERN.fullmatch(header)
if match:
content_length = int(match.group(1))
logger.debug("Content length: %s", content_length)

# Check if all headers have been read (as indicated by an empty line \r\n)
if content_length and not header.strip():
# Read body
body = await reader.readexactly(content_length)
if not body:
break

try:
message = json.loads(body, object_hook=protocol.structure_message)
protocol.handle_message(message)
except Exception as exc:
logger.exception("Unable to handle message")
if error_handler:
error_handler(exc, JsonRpcException)
finally:
# Reset
content_length = 0


def run(
stop_event: threading.Event,
reader: Reader,
protocol: JsonRPCProtocol,
logger: logging.Logger | None = None,
error_handler: Callable[[Exception, type[JsonRpcException]], Any] | None = None,
):
"""Run a main message processing loop, synchronously

Parameters
----------
stop_event
A ``threading.Event`` used to break the main loop

reader
The reader to read messages from

protocol
The protocol instance that should handle the messages

logger
The logger instance to use

error_handler
Function to call when an error is encountered.
"""

CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$")
content_length = 0
logger = logger or logging.getLogger(__name__)

while not stop_event.is_set():
# Read a header line
header = reader.readline()
if not header:
break

# Extract content length if possible
if not content_length:
match = CONTENT_LENGTH_PATTERN.fullmatch(header)
if match:
content_length = int(match.group(1))
logger.debug("Content length: %s", content_length)

# Check if all headers have been read (as indicated by an empty line \r\n)
if content_length and not header.strip():
# Read body
body = reader.read(content_length)
if not body:
break

try:
message = json.loads(body, object_hook=protocol.structure_message)
protocol.handle_message(message)
except Exception as exc:
logger.exception("Unable to handle message")
if error_handler:
error_handler(exc, JsonRpcException)
finally:
# Reset
content_length = 0
74 changes: 12 additions & 62 deletions pygls/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,73 +19,22 @@
import asyncio
import json
import logging
import re
import sys
from concurrent.futures import ThreadPoolExecutor
from threading import Event
from typing import Any, BinaryIO, Callable, Optional, Type, TypeVar, Union

import cattrs
from pygls.exceptions import (
FeatureNotificationError,
JsonRpcInternalError,
PyglsError,
JsonRpcException,
FeatureRequestError,
)
from pygls.protocol import JsonRPCProtocol

from pygls.exceptions import JsonRpcException, PyglsError
from pygls.io_ import StdinAsyncReader, run_async
from pygls.protocol import JsonRPCProtocol

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable)

ServerErrors = Union[
PyglsError,
JsonRpcException,
Type[JsonRpcInternalError],
Type[FeatureNotificationError],
Type[FeatureRequestError],
]


async def aio_readline(loop, executor, stop_event, rfile, proxy):
"""Reads data from stdin in separate thread (asynchronously)."""

CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$")

# Initialize message buffer
message = []
content_length = 0

while not stop_event.is_set() and not rfile.closed:
# Read a header line
header = await loop.run_in_executor(executor, rfile.readline)
if not header:
break
message.append(header)

# Extract content length if possible
if not content_length:
match = CONTENT_LENGTH_PATTERN.fullmatch(header)
if match:
content_length = int(match.group(1))
logger.debug("Content length: %s", content_length)

# Check if all headers have been read (as indicated by an empty line \r\n)
if content_length and not header.strip():
# Read body
body = await loop.run_in_executor(executor, rfile.read, content_length)
if not body:
break
message.append(body)

# Pass message to language server protocol
proxy(b"".join(message))

# Reset the buffer
message = []
content_length = 0
ServerErrors = Union[type[PyglsError], type[JsonRpcException]]


class StdOutTransportAdapter:
Expand Down Expand Up @@ -228,19 +177,20 @@ def start_io(
logger.info("Starting IO server")

self._stop_event = Event()
reader = StdinAsyncReader(stdin or sys.stdin.buffer, self.thread_pool)
transport = StdOutTransportAdapter(
stdin or sys.stdin.buffer, stdout or sys.stdout.buffer
)
self.protocol.connection_made(transport) # type: ignore[arg-type]

try:
self.loop.run_until_complete(
aio_readline(
self.loop,
self.thread_pool,
self._stop_event,
stdin or sys.stdin.buffer,
self.protocol.data_received,
asyncio.run(
run_async(
stop_event=self._stop_event,
reader=reader,
protocol=self.protocol,
logger=logger,
error_handler=self.report_server_error,
)
)
except BrokenPipeError:
Expand Down
Loading
Loading