Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize implementation of the transport #1847

Merged
merged 7 commits into from
Sep 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 84 additions & 51 deletions plugin/core/transports.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from .logging import exception_log, debug
from .types import TCP_CONNECT_TIMEOUT
from .types import TransportConfig
from .typing import Dict, Any, Optional, IO, Protocol, List, Callable, Tuple
from abc import ABCMeta, abstractmethod
from .typing import Dict, Any, Optional, IO, Protocol, Generic, List, Callable, Tuple, TypeVar, Union
from contextlib import closing
from functools import partial
from queue import Queue
Expand All @@ -18,49 +17,100 @@
import weakref


class Transport(metaclass=ABCMeta):
T = TypeVar('T')
T_contra = TypeVar('T_contra', contravariant=True)

@abstractmethod
def send(self, payload: Dict[str, Any]) -> None:
pass

@abstractmethod
class StopLoopError(Exception):
pass


class Transport(Generic[T]):

def send(self, payload: T) -> None:
raise NotImplementedError()

def close(self) -> None:
pass
raise NotImplementedError()


class TransportCallbacks(Protocol):
class TransportCallbacks(Protocol[T_contra]):

def on_transport_close(self, exit_code: int, exception: Optional[Exception]) -> None:
...

def on_payload(self, payload: Dict[str, Any]) -> None:
def on_payload(self, payload: T_contra) -> None:
...

def on_stderr_message(self, message: str) -> None:
...


class JsonRpcTransport(Transport):
class AbstractProcessor(Generic[T]):

def write_data(self, writer: IO[bytes], data: T) -> None:
raise NotImplementedError()

def read_data(self, reader: IO[bytes]) -> Optional[T]:
raise NotImplementedError()


class JsonRpcProcessor(AbstractProcessor[Dict[str, Any]]):

def write_data(self, writer: IO[bytes], data: Dict[str, Any]) -> None:
body = self._encode(data)
writer.writelines(("Content-Length: {}\r\n\r\n".format(len(body)).encode('ascii'), body))

def read_data(self, reader: IO[bytes]) -> Optional[Dict[str, Any]]:
headers = http.client.parse_headers(reader) # type: ignore
try:
body = reader.read(int(headers.get("Content-Length")))
except TypeError:
# Expected error on process stopping. Stop the read loop.
raise StopLoopError()
Comment on lines +65 to +70
Copy link
Member Author

@rchl rchl Sep 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed so that explicit action is used to stop the read loop instead of relying on TypeError doing it.

We could maybe just raise StopLoopError() earlier by checking if headers.get("Content-Length") is None.

try:
return self._decode(body)
except Exception as ex:
exception_log("JSON decode error", ex)
return None

@staticmethod
def _encode(data: Dict[str, Any]) -> bytes:
return json.dumps(
data,
ensure_ascii=False,
sort_keys=False,
check_circular=False,
separators=(',', ':')
).encode('utf-8')

@staticmethod
def _decode(message: bytes) -> Dict[str, Any]:
return json.loads(message.decode('utf-8'))


class ProcessTransport(Transport[T]):

def __init__(self, name: str, process: subprocess.Popen, socket: Optional[socket.socket], reader: IO[bytes],
writer: IO[bytes], stderr: Optional[IO[bytes]], callback_object: TransportCallbacks) -> None:
writer: IO[bytes], stderr: Optional[IO[bytes]], processor: AbstractProcessor[T],
callback_object: TransportCallbacks[T]) -> None:
self._closed = False
self._process = process
self._socket = socket
self._reader = reader
self._writer = writer
self._stderr = stderr
self._processor = processor
self._reader_thread = threading.Thread(target=self._read_loop, name='{}-reader'.format(name))
self._writer_thread = threading.Thread(target=self._write_loop, name='{}-writer'.format(name))
self._stderr_thread = threading.Thread(target=self._stderr_loop, name='{}-stderr'.format(name))
self._callback_object = weakref.ref(callback_object)
self._send_queue = Queue(0) # type: Queue[Optional[Dict[str, Any]]]
self._send_queue = Queue(0) # type: Queue[Union[T, None]]
self._reader_thread.start()
self._writer_thread.start()
self._stderr_thread.start()

def send(self, payload: Dict[str, Any]) -> None:
def send(self, payload: T) -> None:
self._send_queue.put_nowait(payload)

def close(self) -> None:
Expand All @@ -87,25 +137,17 @@ def __del__(self) -> None:
def _read_loop(self) -> None:
try:
while self._reader:
headers = http.client.parse_headers(self._reader) # type: ignore
body = self._reader.read(int(headers.get("Content-Length")))
try:
payload = _decode(body)

def invoke(p: Dict[str, Any]) -> None:
callback_object = self._callback_object()
if callback_object:
callback_object.on_payload(p)

sublime.set_timeout_async(partial(invoke, payload))
except Exception as ex:
exception_log("JSON decode error", ex)
payload = self._processor.read_data(self._reader)
if payload is None:
continue
finally:
# We don't need these anymore
del body
del headers
except (AttributeError, BrokenPipeError, TypeError):

def invoke(p: T) -> None:
callback_object = self._callback_object()
if callback_object:
callback_object.on_payload(p)

sublime.set_timeout_async(partial(invoke, payload))
except (AttributeError, BrokenPipeError, StopLoopError):
pass
except Exception as ex:
exception_log("Unexpected exception", ex)
Expand Down Expand Up @@ -146,8 +188,7 @@ def _write_loop(self) -> None:
d = self._send_queue.get()
if d is None:
break
body = _encode(d)
self._writer.writelines(("Content-Length: {}\r\n\r\n".format(len(body)).encode('ascii'), body))
self._processor.write_data(self._writer, d)
self._writer.flush()
except (BrokenPipeError, AttributeError):
pass
Expand Down Expand Up @@ -176,8 +217,12 @@ def _stderr_loop(self) -> None:
self._send_queue.put_nowait(None)


# Can be a singleton since it doesn't hold any state.
json_rpc_processor = JsonRpcProcessor()


def create_transport(config: TransportConfig, cwd: Optional[str],
callback_object: TransportCallbacks) -> JsonRpcTransport:
callback_object: TransportCallbacks) -> Transport[Dict[str, Any]]:
if config.tcp_port is not None:
assert config.tcp_port is not None
if config.tcp_port < 0:
Expand Down Expand Up @@ -214,8 +259,10 @@ def start_subprocess() -> subprocess.Popen:
else:
reader = process.stdout # type: ignore
writer = process.stdin # type: ignore
assert writer
return JsonRpcTransport(config.name, process, sock, reader, writer, process.stderr, callback_object)
if not reader or not writer:
raise RuntimeError('Failed initializing transport: reader: {}, writer: {}'.format(reader, writer))
return ProcessTransport(config.name, process, sock, reader, writer, process.stderr, json_rpc_processor,
callback_object)


_subprocesses = weakref.WeakSet() # type: weakref.WeakSet[subprocess.Popen]
Expand Down Expand Up @@ -321,17 +368,3 @@ def _connect_tcp(port: int) -> Optional[socket.socket]:
except ConnectionRefusedError:
pass
return None


def _encode(d: Dict[str, Any]) -> bytes:
return json.dumps(
d,
ensure_ascii=False,
sort_keys=False,
check_circular=False,
separators=(',', ':')
).encode('utf-8')


def _decode(message: bytes) -> Dict[str, Any]:
return json.loads(message.decode('utf-8'))
32 changes: 16 additions & 16 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from LSP.plugin.core.protocol import Point, Range, Request, Notification
from LSP.plugin.core.transports import _encode, _decode
from LSP.plugin.core.protocol import Point, Position, Range, RangeLsp, Request, Notification
from LSP.plugin.core.transports import JsonRpcProcessor
import unittest


LSP_START_POSITION = {'line': 10, 'character': 4}
LSP_END_POSITION = {'line': 11, 'character': 3}
LSP_RANGE = {'start': LSP_START_POSITION, 'end': LSP_END_POSITION}
LSP_START_POSITION = {'line': 10, 'character': 4} # type: Position
LSP_END_POSITION = {'line': 11, 'character': 3} # type: Position
LSP_RANGE = {'start': LSP_START_POSITION, 'end': LSP_END_POSITION} # type: RangeLsp
LSP_MINIMAL_DIAGNOSTIC = {
'message': 'message',
'range': LSP_RANGE
Expand All @@ -21,7 +21,7 @@

class PointTests(unittest.TestCase):

def test_lsp_conversion(self):
def test_lsp_conversion(self) -> None:
point = Point.from_lsp(LSP_START_POSITION)
self.assertEqual(point.row, 10)
self.assertEqual(point.col, 4)
Expand All @@ -32,7 +32,7 @@ def test_lsp_conversion(self):

class RangeTests(unittest.TestCase):

def test_lsp_conversion(self):
def test_lsp_conversion(self) -> None:
range = Range.from_lsp(LSP_RANGE)
self.assertEqual(range.start.row, 10)
self.assertEqual(range.start.col, 4)
Expand All @@ -44,7 +44,7 @@ def test_lsp_conversion(self):
self.assertEqual(lsp_range['end']['line'], 11)
self.assertEqual(lsp_range['end']['character'], 3)

def test_contains(self):
def test_contains(self) -> None:
range = Range.from_lsp(LSP_RANGE)
point = Point.from_lsp(LSP_START_POSITION)
self.assertTrue(range.contains(point))
Expand All @@ -67,7 +67,7 @@ def test_contains(self):
point = Point.from_lsp({'line': 0, 'character': 4})
self.assertTrue(range.contains(point))

def test_intersects(self):
def test_intersects(self) -> None:
# range2 fully contained within range1
range1 = Range.from_lsp({
'start': {'line': 0, 'character': 0},
Expand Down Expand Up @@ -128,24 +128,24 @@ def test_extend(self) -> None:


class EncodingTests(unittest.TestCase):
def test_encode(self):
encoded = _encode({"text": "😃"})
def test_encode(self) -> None:
encoded = JsonRpcProcessor._encode({"text": "😃"})
self.assertEqual(encoded, b'{"text":"\xF0\x9F\x98\x83"}')
decoded = _decode(encoded)
decoded = JsonRpcProcessor._decode(encoded)
self.assertEqual(decoded, {"text": "😃"})


class RequestTests(unittest.TestCase):

def test_initialize(self):
def test_initialize(self) -> None:
req = Request.initialize({"param": 1})
payload = req.to_payload(1)
self.assertEqual(payload["jsonrpc"], "2.0")
self.assertEqual(payload["id"], 1)
self.assertEqual(payload["method"], "initialize")
self.assertEqual(payload["params"], {"param": 1})

def test_shutdown(self):
def test_shutdown(self) -> None:
req = Request.shutdown()
payload = req.to_payload(1)
self.assertEqual(payload["jsonrpc"], "2.0")
Expand All @@ -156,15 +156,15 @@ def test_shutdown(self):

class NotificationTests(unittest.TestCase):

def test_initialized(self):
def test_initialized(self) -> None:
notification = Notification.initialized()
payload = notification.to_payload()
self.assertEqual(payload["jsonrpc"], "2.0")
self.assertNotIn("id", payload)
self.assertEqual(payload["method"], "initialized")
self.assertEqual(payload["params"], dict())

def test_exit(self):
def test_exit(self) -> None:
notification = Notification.exit()
payload = notification.to_payload()
self.assertEqual(payload["jsonrpc"], "2.0")
Expand Down