Skip to content

Commit

Permalink
Generalize implementation of the transport (#1847)
Browse files Browse the repository at this point in the history
* Generalize implementation of the transport

Extracted the JSON RPC-specific handling into a separate processor
and used composition concept to pass and use it within the transport.

That should allow re-using the transport implementation in the file
watcher to implement a transport that uses plain lines of text rather
than JSON RPC.

Not using abstract classes anymore as that is not compatible with
generic types.
  • Loading branch information
rchl authored Sep 18, 2021
1 parent 5b32049 commit 8ff07b0
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 67 deletions.
135 changes: 84 additions & 51 deletions plugin/core/transports.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from .logging import exception_log, debug
from .types import TCP_CONNECT_TIMEOUT
from .types import TransportConfig
from .typing import Dict, Any, Optional, IO, Protocol, List, Callable, Tuple
from abc import ABCMeta, abstractmethod
from .typing import Dict, Any, Optional, IO, Protocol, Generic, List, Callable, Tuple, TypeVar, Union
from contextlib import closing
from functools import partial
from queue import Queue
Expand All @@ -18,49 +17,100 @@
import weakref


class Transport(metaclass=ABCMeta):
T = TypeVar('T')
T_contra = TypeVar('T_contra', contravariant=True)

@abstractmethod
def send(self, payload: Dict[str, Any]) -> None:
pass

@abstractmethod
class StopLoopError(Exception):
pass


class Transport(Generic[T]):

def send(self, payload: T) -> None:
raise NotImplementedError()

def close(self) -> None:
pass
raise NotImplementedError()


class TransportCallbacks(Protocol):
class TransportCallbacks(Protocol[T_contra]):

def on_transport_close(self, exit_code: int, exception: Optional[Exception]) -> None:
...

def on_payload(self, payload: Dict[str, Any]) -> None:
def on_payload(self, payload: T_contra) -> None:
...

def on_stderr_message(self, message: str) -> None:
...


class JsonRpcTransport(Transport):
class AbstractProcessor(Generic[T]):

def write_data(self, writer: IO[bytes], data: T) -> None:
raise NotImplementedError()

def read_data(self, reader: IO[bytes]) -> Optional[T]:
raise NotImplementedError()


class JsonRpcProcessor(AbstractProcessor[Dict[str, Any]]):

def write_data(self, writer: IO[bytes], data: Dict[str, Any]) -> None:
body = self._encode(data)
writer.writelines(("Content-Length: {}\r\n\r\n".format(len(body)).encode('ascii'), body))

def read_data(self, reader: IO[bytes]) -> Optional[Dict[str, Any]]:
headers = http.client.parse_headers(reader) # type: ignore
try:
body = reader.read(int(headers.get("Content-Length")))
except TypeError:
# Expected error on process stopping. Stop the read loop.
raise StopLoopError()
try:
return self._decode(body)
except Exception as ex:
exception_log("JSON decode error", ex)
return None

@staticmethod
def _encode(data: Dict[str, Any]) -> bytes:
return json.dumps(
data,
ensure_ascii=False,
sort_keys=False,
check_circular=False,
separators=(',', ':')
).encode('utf-8')

@staticmethod
def _decode(message: bytes) -> Dict[str, Any]:
return json.loads(message.decode('utf-8'))


class ProcessTransport(Transport[T]):

def __init__(self, name: str, process: subprocess.Popen, socket: Optional[socket.socket], reader: IO[bytes],
writer: IO[bytes], stderr: Optional[IO[bytes]], callback_object: TransportCallbacks) -> None:
writer: IO[bytes], stderr: Optional[IO[bytes]], processor: AbstractProcessor[T],
callback_object: TransportCallbacks[T]) -> None:
self._closed = False
self._process = process
self._socket = socket
self._reader = reader
self._writer = writer
self._stderr = stderr
self._processor = processor
self._reader_thread = threading.Thread(target=self._read_loop, name='{}-reader'.format(name))
self._writer_thread = threading.Thread(target=self._write_loop, name='{}-writer'.format(name))
self._stderr_thread = threading.Thread(target=self._stderr_loop, name='{}-stderr'.format(name))
self._callback_object = weakref.ref(callback_object)
self._send_queue = Queue(0) # type: Queue[Optional[Dict[str, Any]]]
self._send_queue = Queue(0) # type: Queue[Union[T, None]]
self._reader_thread.start()
self._writer_thread.start()
self._stderr_thread.start()

def send(self, payload: Dict[str, Any]) -> None:
def send(self, payload: T) -> None:
self._send_queue.put_nowait(payload)

def close(self) -> None:
Expand All @@ -87,25 +137,17 @@ def __del__(self) -> None:
def _read_loop(self) -> None:
try:
while self._reader:
headers = http.client.parse_headers(self._reader) # type: ignore
body = self._reader.read(int(headers.get("Content-Length")))
try:
payload = _decode(body)

def invoke(p: Dict[str, Any]) -> None:
callback_object = self._callback_object()
if callback_object:
callback_object.on_payload(p)

sublime.set_timeout_async(partial(invoke, payload))
except Exception as ex:
exception_log("JSON decode error", ex)
payload = self._processor.read_data(self._reader)
if payload is None:
continue
finally:
# We don't need these anymore
del body
del headers
except (AttributeError, BrokenPipeError, TypeError):

def invoke(p: T) -> None:
callback_object = self._callback_object()
if callback_object:
callback_object.on_payload(p)

sublime.set_timeout_async(partial(invoke, payload))
except (AttributeError, BrokenPipeError, StopLoopError):
pass
except Exception as ex:
exception_log("Unexpected exception", ex)
Expand Down Expand Up @@ -146,8 +188,7 @@ def _write_loop(self) -> None:
d = self._send_queue.get()
if d is None:
break
body = _encode(d)
self._writer.writelines(("Content-Length: {}\r\n\r\n".format(len(body)).encode('ascii'), body))
self._processor.write_data(self._writer, d)
self._writer.flush()
except (BrokenPipeError, AttributeError):
pass
Expand Down Expand Up @@ -176,8 +217,12 @@ def _stderr_loop(self) -> None:
self._send_queue.put_nowait(None)


# Can be a singleton since it doesn't hold any state.
json_rpc_processor = JsonRpcProcessor()


def create_transport(config: TransportConfig, cwd: Optional[str],
callback_object: TransportCallbacks) -> JsonRpcTransport:
callback_object: TransportCallbacks) -> Transport[Dict[str, Any]]:
if config.tcp_port is not None:
assert config.tcp_port is not None
if config.tcp_port < 0:
Expand Down Expand Up @@ -214,8 +259,10 @@ def start_subprocess() -> subprocess.Popen:
else:
reader = process.stdout # type: ignore
writer = process.stdin # type: ignore
assert writer
return JsonRpcTransport(config.name, process, sock, reader, writer, process.stderr, callback_object)
if not reader or not writer:
raise RuntimeError('Failed initializing transport: reader: {}, writer: {}'.format(reader, writer))
return ProcessTransport(config.name, process, sock, reader, writer, process.stderr, json_rpc_processor,
callback_object)


_subprocesses = weakref.WeakSet() # type: weakref.WeakSet[subprocess.Popen]
Expand Down Expand Up @@ -321,17 +368,3 @@ def _connect_tcp(port: int) -> Optional[socket.socket]:
except ConnectionRefusedError:
pass
return None


def _encode(d: Dict[str, Any]) -> bytes:
return json.dumps(
d,
ensure_ascii=False,
sort_keys=False,
check_circular=False,
separators=(',', ':')
).encode('utf-8')


def _decode(message: bytes) -> Dict[str, Any]:
return json.loads(message.decode('utf-8'))
32 changes: 16 additions & 16 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from LSP.plugin.core.protocol import Point, Range, Request, Notification
from LSP.plugin.core.transports import _encode, _decode
from LSP.plugin.core.protocol import Point, Position, Range, RangeLsp, Request, Notification
from LSP.plugin.core.transports import JsonRpcProcessor
import unittest


LSP_START_POSITION = {'line': 10, 'character': 4}
LSP_END_POSITION = {'line': 11, 'character': 3}
LSP_RANGE = {'start': LSP_START_POSITION, 'end': LSP_END_POSITION}
LSP_START_POSITION = {'line': 10, 'character': 4} # type: Position
LSP_END_POSITION = {'line': 11, 'character': 3} # type: Position
LSP_RANGE = {'start': LSP_START_POSITION, 'end': LSP_END_POSITION} # type: RangeLsp
LSP_MINIMAL_DIAGNOSTIC = {
'message': 'message',
'range': LSP_RANGE
Expand All @@ -21,7 +21,7 @@

class PointTests(unittest.TestCase):

def test_lsp_conversion(self):
def test_lsp_conversion(self) -> None:
point = Point.from_lsp(LSP_START_POSITION)
self.assertEqual(point.row, 10)
self.assertEqual(point.col, 4)
Expand All @@ -32,7 +32,7 @@ def test_lsp_conversion(self):

class RangeTests(unittest.TestCase):

def test_lsp_conversion(self):
def test_lsp_conversion(self) -> None:
range = Range.from_lsp(LSP_RANGE)
self.assertEqual(range.start.row, 10)
self.assertEqual(range.start.col, 4)
Expand All @@ -44,7 +44,7 @@ def test_lsp_conversion(self):
self.assertEqual(lsp_range['end']['line'], 11)
self.assertEqual(lsp_range['end']['character'], 3)

def test_contains(self):
def test_contains(self) -> None:
range = Range.from_lsp(LSP_RANGE)
point = Point.from_lsp(LSP_START_POSITION)
self.assertTrue(range.contains(point))
Expand All @@ -67,7 +67,7 @@ def test_contains(self):
point = Point.from_lsp({'line': 0, 'character': 4})
self.assertTrue(range.contains(point))

def test_intersects(self):
def test_intersects(self) -> None:
# range2 fully contained within range1
range1 = Range.from_lsp({
'start': {'line': 0, 'character': 0},
Expand Down Expand Up @@ -128,24 +128,24 @@ def test_extend(self) -> None:


class EncodingTests(unittest.TestCase):
def test_encode(self):
encoded = _encode({"text": "😃"})
def test_encode(self) -> None:
encoded = JsonRpcProcessor._encode({"text": "😃"})
self.assertEqual(encoded, b'{"text":"\xF0\x9F\x98\x83"}')
decoded = _decode(encoded)
decoded = JsonRpcProcessor._decode(encoded)
self.assertEqual(decoded, {"text": "😃"})


class RequestTests(unittest.TestCase):

def test_initialize(self):
def test_initialize(self) -> None:
req = Request.initialize({"param": 1})
payload = req.to_payload(1)
self.assertEqual(payload["jsonrpc"], "2.0")
self.assertEqual(payload["id"], 1)
self.assertEqual(payload["method"], "initialize")
self.assertEqual(payload["params"], {"param": 1})

def test_shutdown(self):
def test_shutdown(self) -> None:
req = Request.shutdown()
payload = req.to_payload(1)
self.assertEqual(payload["jsonrpc"], "2.0")
Expand All @@ -156,15 +156,15 @@ def test_shutdown(self):

class NotificationTests(unittest.TestCase):

def test_initialized(self):
def test_initialized(self) -> None:
notification = Notification.initialized()
payload = notification.to_payload()
self.assertEqual(payload["jsonrpc"], "2.0")
self.assertNotIn("id", payload)
self.assertEqual(payload["method"], "initialized")
self.assertEqual(payload["params"], dict())

def test_exit(self):
def test_exit(self) -> None:
notification = Notification.exit()
payload = notification.to_payload()
self.assertEqual(payload["jsonrpc"], "2.0")
Expand Down

0 comments on commit 8ff07b0

Please sign in to comment.