From 8ff07b0512f2c3ca5beb4c3f991a7a8dad35f90b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Ch=C5=82odnicki?= Date: Sat, 18 Sep 2021 22:23:27 +0200 Subject: [PATCH] Generalize implementation of the transport (#1847) * Generalize implementation of the transport Extracted the JSON RPC-specific handling into a separate processor and used composition concept to pass and use it within the transport. That should allow re-using the transport implementation in the file watcher to implement a transport that uses plain lines of text rather than JSON RPC. Not using abstract classes anymore as that is not compatible with generic types. --- plugin/core/transports.py | 135 ++++++++++++++++++++++++-------------- tests/test_protocol.py | 32 ++++----- 2 files changed, 100 insertions(+), 67 deletions(-) diff --git a/plugin/core/transports.py b/plugin/core/transports.py index f1c894414..1d0eabdfd 100644 --- a/plugin/core/transports.py +++ b/plugin/core/transports.py @@ -1,8 +1,7 @@ from .logging import exception_log, debug from .types import TCP_CONNECT_TIMEOUT from .types import TransportConfig -from .typing import Dict, Any, Optional, IO, Protocol, List, Callable, Tuple -from abc import ABCMeta, abstractmethod +from .typing import Dict, Any, Optional, IO, Protocol, Generic, List, Callable, Tuple, TypeVar, Union from contextlib import closing from functools import partial from queue import Queue @@ -18,49 +17,100 @@ import weakref -class Transport(metaclass=ABCMeta): +T = TypeVar('T') +T_contra = TypeVar('T_contra', contravariant=True) - @abstractmethod - def send(self, payload: Dict[str, Any]) -> None: - pass - @abstractmethod +class StopLoopError(Exception): + pass + + +class Transport(Generic[T]): + + def send(self, payload: T) -> None: + raise NotImplementedError() + def close(self) -> None: - pass + raise NotImplementedError() -class TransportCallbacks(Protocol): +class TransportCallbacks(Protocol[T_contra]): def on_transport_close(self, exit_code: int, exception: Optional[Exception]) -> None: ... - def on_payload(self, payload: Dict[str, Any]) -> None: + def on_payload(self, payload: T_contra) -> None: ... def on_stderr_message(self, message: str) -> None: ... -class JsonRpcTransport(Transport): +class AbstractProcessor(Generic[T]): + + def write_data(self, writer: IO[bytes], data: T) -> None: + raise NotImplementedError() + + def read_data(self, reader: IO[bytes]) -> Optional[T]: + raise NotImplementedError() + + +class JsonRpcProcessor(AbstractProcessor[Dict[str, Any]]): + + def write_data(self, writer: IO[bytes], data: Dict[str, Any]) -> None: + body = self._encode(data) + writer.writelines(("Content-Length: {}\r\n\r\n".format(len(body)).encode('ascii'), body)) + + def read_data(self, reader: IO[bytes]) -> Optional[Dict[str, Any]]: + headers = http.client.parse_headers(reader) # type: ignore + try: + body = reader.read(int(headers.get("Content-Length"))) + except TypeError: + # Expected error on process stopping. Stop the read loop. + raise StopLoopError() + try: + return self._decode(body) + except Exception as ex: + exception_log("JSON decode error", ex) + return None + + @staticmethod + def _encode(data: Dict[str, Any]) -> bytes: + return json.dumps( + data, + ensure_ascii=False, + sort_keys=False, + check_circular=False, + separators=(',', ':') + ).encode('utf-8') + + @staticmethod + def _decode(message: bytes) -> Dict[str, Any]: + return json.loads(message.decode('utf-8')) + + +class ProcessTransport(Transport[T]): def __init__(self, name: str, process: subprocess.Popen, socket: Optional[socket.socket], reader: IO[bytes], - writer: IO[bytes], stderr: Optional[IO[bytes]], callback_object: TransportCallbacks) -> None: + writer: IO[bytes], stderr: Optional[IO[bytes]], processor: AbstractProcessor[T], + callback_object: TransportCallbacks[T]) -> None: self._closed = False self._process = process self._socket = socket self._reader = reader self._writer = writer self._stderr = stderr + self._processor = processor self._reader_thread = threading.Thread(target=self._read_loop, name='{}-reader'.format(name)) self._writer_thread = threading.Thread(target=self._write_loop, name='{}-writer'.format(name)) self._stderr_thread = threading.Thread(target=self._stderr_loop, name='{}-stderr'.format(name)) self._callback_object = weakref.ref(callback_object) - self._send_queue = Queue(0) # type: Queue[Optional[Dict[str, Any]]] + self._send_queue = Queue(0) # type: Queue[Union[T, None]] self._reader_thread.start() self._writer_thread.start() self._stderr_thread.start() - def send(self, payload: Dict[str, Any]) -> None: + def send(self, payload: T) -> None: self._send_queue.put_nowait(payload) def close(self) -> None: @@ -87,25 +137,17 @@ def __del__(self) -> None: def _read_loop(self) -> None: try: while self._reader: - headers = http.client.parse_headers(self._reader) # type: ignore - body = self._reader.read(int(headers.get("Content-Length"))) - try: - payload = _decode(body) - - def invoke(p: Dict[str, Any]) -> None: - callback_object = self._callback_object() - if callback_object: - callback_object.on_payload(p) - - sublime.set_timeout_async(partial(invoke, payload)) - except Exception as ex: - exception_log("JSON decode error", ex) + payload = self._processor.read_data(self._reader) + if payload is None: continue - finally: - # We don't need these anymore - del body - del headers - except (AttributeError, BrokenPipeError, TypeError): + + def invoke(p: T) -> None: + callback_object = self._callback_object() + if callback_object: + callback_object.on_payload(p) + + sublime.set_timeout_async(partial(invoke, payload)) + except (AttributeError, BrokenPipeError, StopLoopError): pass except Exception as ex: exception_log("Unexpected exception", ex) @@ -146,8 +188,7 @@ def _write_loop(self) -> None: d = self._send_queue.get() if d is None: break - body = _encode(d) - self._writer.writelines(("Content-Length: {}\r\n\r\n".format(len(body)).encode('ascii'), body)) + self._processor.write_data(self._writer, d) self._writer.flush() except (BrokenPipeError, AttributeError): pass @@ -176,8 +217,12 @@ def _stderr_loop(self) -> None: self._send_queue.put_nowait(None) +# Can be a singleton since it doesn't hold any state. +json_rpc_processor = JsonRpcProcessor() + + def create_transport(config: TransportConfig, cwd: Optional[str], - callback_object: TransportCallbacks) -> JsonRpcTransport: + callback_object: TransportCallbacks) -> Transport[Dict[str, Any]]: if config.tcp_port is not None: assert config.tcp_port is not None if config.tcp_port < 0: @@ -214,8 +259,10 @@ def start_subprocess() -> subprocess.Popen: else: reader = process.stdout # type: ignore writer = process.stdin # type: ignore - assert writer - return JsonRpcTransport(config.name, process, sock, reader, writer, process.stderr, callback_object) + if not reader or not writer: + raise RuntimeError('Failed initializing transport: reader: {}, writer: {}'.format(reader, writer)) + return ProcessTransport(config.name, process, sock, reader, writer, process.stderr, json_rpc_processor, + callback_object) _subprocesses = weakref.WeakSet() # type: weakref.WeakSet[subprocess.Popen] @@ -321,17 +368,3 @@ def _connect_tcp(port: int) -> Optional[socket.socket]: except ConnectionRefusedError: pass return None - - -def _encode(d: Dict[str, Any]) -> bytes: - return json.dumps( - d, - ensure_ascii=False, - sort_keys=False, - check_circular=False, - separators=(',', ':') - ).encode('utf-8') - - -def _decode(message: bytes) -> Dict[str, Any]: - return json.loads(message.decode('utf-8')) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 953a24110..8e14f1d6e 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,11 +1,11 @@ -from LSP.plugin.core.protocol import Point, Range, Request, Notification -from LSP.plugin.core.transports import _encode, _decode +from LSP.plugin.core.protocol import Point, Position, Range, RangeLsp, Request, Notification +from LSP.plugin.core.transports import JsonRpcProcessor import unittest -LSP_START_POSITION = {'line': 10, 'character': 4} -LSP_END_POSITION = {'line': 11, 'character': 3} -LSP_RANGE = {'start': LSP_START_POSITION, 'end': LSP_END_POSITION} +LSP_START_POSITION = {'line': 10, 'character': 4} # type: Position +LSP_END_POSITION = {'line': 11, 'character': 3} # type: Position +LSP_RANGE = {'start': LSP_START_POSITION, 'end': LSP_END_POSITION} # type: RangeLsp LSP_MINIMAL_DIAGNOSTIC = { 'message': 'message', 'range': LSP_RANGE @@ -21,7 +21,7 @@ class PointTests(unittest.TestCase): - def test_lsp_conversion(self): + def test_lsp_conversion(self) -> None: point = Point.from_lsp(LSP_START_POSITION) self.assertEqual(point.row, 10) self.assertEqual(point.col, 4) @@ -32,7 +32,7 @@ def test_lsp_conversion(self): class RangeTests(unittest.TestCase): - def test_lsp_conversion(self): + def test_lsp_conversion(self) -> None: range = Range.from_lsp(LSP_RANGE) self.assertEqual(range.start.row, 10) self.assertEqual(range.start.col, 4) @@ -44,7 +44,7 @@ def test_lsp_conversion(self): self.assertEqual(lsp_range['end']['line'], 11) self.assertEqual(lsp_range['end']['character'], 3) - def test_contains(self): + def test_contains(self) -> None: range = Range.from_lsp(LSP_RANGE) point = Point.from_lsp(LSP_START_POSITION) self.assertTrue(range.contains(point)) @@ -67,7 +67,7 @@ def test_contains(self): point = Point.from_lsp({'line': 0, 'character': 4}) self.assertTrue(range.contains(point)) - def test_intersects(self): + def test_intersects(self) -> None: # range2 fully contained within range1 range1 = Range.from_lsp({ 'start': {'line': 0, 'character': 0}, @@ -128,16 +128,16 @@ def test_extend(self) -> None: class EncodingTests(unittest.TestCase): - def test_encode(self): - encoded = _encode({"text": "😃"}) + def test_encode(self) -> None: + encoded = JsonRpcProcessor._encode({"text": "😃"}) self.assertEqual(encoded, b'{"text":"\xF0\x9F\x98\x83"}') - decoded = _decode(encoded) + decoded = JsonRpcProcessor._decode(encoded) self.assertEqual(decoded, {"text": "😃"}) class RequestTests(unittest.TestCase): - def test_initialize(self): + def test_initialize(self) -> None: req = Request.initialize({"param": 1}) payload = req.to_payload(1) self.assertEqual(payload["jsonrpc"], "2.0") @@ -145,7 +145,7 @@ def test_initialize(self): self.assertEqual(payload["method"], "initialize") self.assertEqual(payload["params"], {"param": 1}) - def test_shutdown(self): + def test_shutdown(self) -> None: req = Request.shutdown() payload = req.to_payload(1) self.assertEqual(payload["jsonrpc"], "2.0") @@ -156,7 +156,7 @@ def test_shutdown(self): class NotificationTests(unittest.TestCase): - def test_initialized(self): + def test_initialized(self) -> None: notification = Notification.initialized() payload = notification.to_payload() self.assertEqual(payload["jsonrpc"], "2.0") @@ -164,7 +164,7 @@ def test_initialized(self): self.assertEqual(payload["method"], "initialized") self.assertEqual(payload["params"], dict()) - def test_exit(self): + def test_exit(self) -> None: notification = Notification.exit() payload = notification.to_payload() self.assertEqual(payload["jsonrpc"], "2.0")