From d1b1271b5486df00cb240ea9a9eff3f1307c7fa1 Mon Sep 17 00:00:00 2001 From: Thomas Buckley-Houston Date: Tue, 25 Oct 2022 20:24:20 -0300 Subject: [PATCH] fix: Improved error handling Notable behaviour change: All errors outside of LSP requests are now sent to the client as `showMessage.type = MessageType.Error` messages. This may result in existing custom LSP servers showing errors that were not seen before, these aren't new errors, just previously undisplayed errors. `JsonRPCProtocol.data_receieved()` now catches unhandled errors. Fixes #277 --- docs/source/pages/advanced_usage.rst | 13 +++ pygls/exceptions.py | 8 ++ pygls/protocol.py | 52 +++++++--- pygls/server.py | 40 +++++++- tests/ls_setup.py | 12 +-- tests/lsp/test_errors.py | 136 +++++++++++++++++++++++++++ 6 files changed, 242 insertions(+), 19 deletions(-) create mode 100644 tests/lsp/test_errors.py diff --git a/docs/source/pages/advanced_usage.rst b/docs/source/pages/advanced_usage.rst index 3a89bb05..89f427cf 100644 --- a/docs/source/pages/advanced_usage.rst +++ b/docs/source/pages/advanced_usage.rst @@ -447,6 +447,19 @@ And method invocation example: server.send_notification('myCustomNotification', 'test data') +Custom Error Reporting +^^^^^^^^^^^^^^^^^^^^^^ + +By default Pygls notifies the client to display any occurences of uncaught exceptions in the +server. To override this behaviour define your own `report_server_error()` method like so: + +.. code:: python + + Class CustomLanguageServer(LanguageServer): + def report_server_error(self, error: Exception, source: Union[PyglsError, JsonRpcException]): + pass + + Workspace ~~~~~~~~~ diff --git a/pygls/exceptions.py b/pygls/exceptions.py index 00611d94..0d3245e2 100644 --- a/pygls/exceptions.py +++ b/pygls/exceptions.py @@ -188,6 +188,14 @@ def __repr__(self): return f'Feature "{self.feature_name}" is already registered.' +class FeatureRequestError(PyglsError): + pass + + +class FeatureNotificationError(PyglsError): + pass + + class MethodTypeNotRegisteredError(PyglsError): def __init__(self, name): diff --git a/pygls/protocol.py b/pygls/protocol.py index 38d49b0b..ebbe72a8 100644 --- a/pygls/protocol.py +++ b/pygls/protocol.py @@ -21,7 +21,6 @@ import logging import re import sys -import traceback import uuid from collections import namedtuple from concurrent.futures import Future @@ -33,7 +32,8 @@ from pygls.constants import ATTR_FEATURE_TYPE from pygls.exceptions import (JsonRpcException, JsonRpcInternalError, JsonRpcInvalidParams, JsonRpcMethodNotFound, JsonRpcRequestCancelled, - MethodTypeNotRegisteredError) + MethodTypeNotRegisteredError, FeatureNotificationError, + FeatureRequestError) from pygls.feature_manager import (FeatureManager, assign_help_attrs, get_help_attrs, is_thread_function) from pygls.lsp import (JsonRPCNotification, JsonRPCRequestMessage, JsonRPCResponseMessage, @@ -316,8 +316,14 @@ def _handle_notification(self, method_name, params): self._execute_notification(handler, params) except (KeyError, JsonRpcMethodNotFound): logger.warning('Ignoring notification for unknown method "%s"', method_name) - except Exception: - logger.exception('Failed to handle notification "%s": %s', method_name, params) + except Exception as error: + logger.exception( + 'Failed to handle notification "%s": %s', + method_name, + params, + exc_info=True + ) + self._server._report_server_error(error, FeatureNotificationError) def _handle_request(self, msg_id, method_name, params): """Handles a request from the client.""" @@ -330,13 +336,27 @@ def _handle_request(self, msg_id, method_name, params): else: self._execute_request(msg_id, handler, params) - except JsonRpcException as e: - logger.exception('Failed to handle request %s %s %s', msg_id, method_name, params) - self._send_response(msg_id, None, e.to_dict()) - except Exception: - logger.exception('Failed to handle request %s %s %s', msg_id, method_name, params) + except JsonRpcException as error: + logger.exception( + 'Failed to handle request %s %s %s', + msg_id, + method_name, + params, + exc_info=True + ) + self._send_response(msg_id, None, error.to_dict()) + self._server._report_server_error(error, FeatureRequestError) + except Exception as error: + logger.exception( + 'Failed to handle request %s %s %s', + msg_id, + method_name, + params, + exc_info=True + ) err = JsonRpcInternalError.of(sys.exc_info()).to_dict() self._send_response(msg_id, None, err) + self._server._report_server_error(error, FeatureRequestError) def _handle_response(self, msg_id, result=None, error=None): """Handles a response from the client.""" @@ -355,6 +375,7 @@ def _handle_response(self, msg_id, result=None, error=None): def _procedure_handler(self, message): """Delegates message to handlers depending on message type.""" + if message.jsonrpc != JsonRPCProtocol.VERSION: logger.warning('Unknown message "%s"', message) return @@ -377,7 +398,6 @@ def _send_data(self, data): """Sends data to the client.""" if not data: return - try: body = data.json(by_alias=True, exclude_unset=True, encoder=default_serializer) logger.info('Sending data: %s', body) @@ -392,8 +412,9 @@ def _send_data(self, data): self.transport.write(header + body) else: self.transport.write(body.decode('utf-8')) - except Exception: - logger.error(traceback.format_exc()) + except Exception as error: + logger.exception("Error sending data", exc_info=True) + self._server._report_server_error(error, JsonRpcInternalError) def _send_response(self, msg_id, result=None, error=None): """Sends a JSON RPC response to the client. @@ -427,6 +448,13 @@ def connection_made(self, transport: asyncio.BaseTransport): self.transport = transport def data_received(self, data: bytes): + try: + self._data_received(data) + except Exception as error: + logger.exception("Error receiving data", exc_info=True) + self._server._report_server_error(error, JsonRpcInternalError) + + def _data_received(self, data: bytes): """Method from base class, called when server receives the data""" logger.debug('Received %r', data) diff --git a/pygls/server.py b/pygls/server.py index da9c4892..06195c54 100644 --- a/pygls/server.py +++ b/pygls/server.py @@ -21,9 +21,10 @@ import sys from concurrent.futures import Future, ThreadPoolExecutor from threading import Event -from typing import Any, Callable, List, Optional, TypeVar +from typing import Any, Callable, List, Optional, TypeVar, Union from pygls import IS_WIN, IS_PYODIDE +from pygls.exceptions import PyglsError, JsonRpcException, FeatureRequestError from pygls.lsp.types import (ApplyWorkspaceEditResponse, ClientCapabilities, ConfigCallbackType, ConfigurationParams, Diagnostic, MessageType, RegistrationParams, ServerCapabilities, TextDocumentSyncKind, UnregistrationParams, @@ -325,6 +326,12 @@ class LanguageServer(Server): `ThreadPoolExecutor` """ + default_error_message = "Unexpected error in LSP server, see server's logs for details" + """ + The default error message sent to the user's editor when this server encounters an uncaught + exception. + """ + def __init__(self, loop=None, protocol_cls=LanguageServerProtocol, max_workers: int = 2): if not issubclass(protocol_cls, LanguageServerProtocol): raise TypeError('Protocol class should be subclass of LanguageServerProtocol') @@ -426,6 +433,37 @@ def show_message_log(self, message, msg_type=MessageType.Log) -> None: """Sends message to the client's output channel.""" self.lsp.show_message_log(message, msg_type) + def _report_server_error(self, error: Exception, source: Union[PyglsError, JsonRpcException]): + # Prevent recursive error reporting + try: + self.report_server_error(error, source) + except Exception: + logger.warning("Failed to report error to client") + + def report_server_error(self, error: Exception, source: Union[PyglsError, JsonRpcException]): + """ + Sends error to the client for displaying. + + By default this fucntion does not handle LSP request errors. This is because LSP requests + require direct responses and so already have a mechanism for including unexpected errors + in the response body. + + All other errors are "out of band" in the sense that the client isn't explicitly waiting + for them. For example diagnostics are returned as notifications, not responses to requests, + and so can seemingly be sent at random. Also for example consider JSON RPC serialization + and deserialization, if a payload cannot be parsed then the whole request/response cycle + cannot be completed and so one of these "out of band" error messages is sent. + + These "out of band" error messages are not a requirement of the LSP spec. Pygls simply + offers this behaviour as a recommended default. It is perfectly reasonble to override this + default. + """ + + if source == FeatureRequestError: + return + + self.show_message(self.default_error_message, msg_type=MessageType.Error) + def thread(self) -> Callable[[F], F]: """Decorator that mark function to execute it in a thread.""" return self.lsp.thread() diff --git a/tests/ls_setup.py b/tests/ls_setup.py index e5d9bbfe..590e98ee 100644 --- a/tests/ls_setup.py +++ b/tests/ls_setup.py @@ -69,10 +69,10 @@ class PyodideClientServer: """Implementation of the `client_server` fixture for use in a pyodide environment.""" - def __init__(self): + def __init__(self, LS=LanguageServer): - self.server = LanguageServer() - self.client = LanguageServer() + self.server = LS() + self.client = LS() self.server.lsp.connection_made(PyodideTestTransportAdapter(self.client)) self.server.lsp._send_only_body = True @@ -112,14 +112,14 @@ def __iter__(self): class NativeClientServer: - def __init__(self): + def __init__(self, LS=LanguageServer): # Client to Server pipe csr, csw = os.pipe() # Server to client pipe scr, scw = os.pipe() # Setup Server - self.server = LanguageServer() + self.server = LS() self.server_thread = threading.Thread( target=self.server.start_io, args=(os.fdopen(csr, "rb"), os.fdopen(scw, "wb")), @@ -127,7 +127,7 @@ def __init__(self): self.server_thread.daemon = True # Setup client - self.client = LanguageServer(asyncio.new_event_loop()) + self.client = LS(asyncio.new_event_loop()) self.client_thread = threading.Thread( target=self.client.start_io, args=(os.fdopen(scr, "rb"), os.fdopen(csw, "wb")), diff --git a/tests/lsp/test_errors.py b/tests/lsp/test_errors.py new file mode 100644 index 00000000..4f5f7484 --- /dev/null +++ b/tests/lsp/test_errors.py @@ -0,0 +1,136 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ + +from typing import Any, List, Union +import time + +import pytest + +from pygls.exceptions import JsonRpcInternalError, PyglsError, JsonRpcException +from pygls.lsp.methods import WINDOW_SHOW_MESSAGE +from pygls.server import LanguageServer +from pygls.lsp.types import MessageType + +from ..conftest import ClientServer + +ERROR_TRIGGER = "test/triggerError" +ERROR_MESSAGE = "Testing errors" + + +class CustomLanguageServerSafe(LanguageServer): + def report_server_error( + self, error: Exception, source: Union[PyglsError, JsonRpcException] + ): + pass + + +class CustomLanguageServerPotentialRecursion(LanguageServer): + def report_server_error( + self, error: Exception, source: Union[PyglsError, JsonRpcException] + ): + raise Exception() + + +class CustomLanguageServerSendAll(LanguageServer): + def report_server_error( + self, error: Exception, source: Union[PyglsError, JsonRpcException] + ): + self.show_message(self.default_error_message, msg_type=MessageType.Error) + + +class ConfiguredLS(ClientServer): + def __init__(self, LS=LanguageServer): + super().__init__(LS) + self.init() + + def init(self): + self.client.messages: List[str] = [] + + @self.server.feature(ERROR_TRIGGER) + def f1(params: Any): + raise Exception(ERROR_MESSAGE) + + @self.client.feature(WINDOW_SHOW_MESSAGE) + def f2(params: Any): + self.client.messages.append(params.message) + + +class CustomConfiguredLSSafe(ConfiguredLS): + def __init__(self): + super().__init__(CustomLanguageServerSafe) + + +class CustomConfiguredLSPotentialRecusrion(ConfiguredLS): + def __init__(self): + super().__init__(CustomLanguageServerPotentialRecursion) + + +class CustomConfiguredLSSendAll(ConfiguredLS): + def __init__(self): + super().__init__(CustomLanguageServerSendAll) + + +@ConfiguredLS.decorate() +def test_request_error_reporting_default(client_server): + client, _ = client_server + assert len(client.messages) == 0 + + with pytest.raises(JsonRpcInternalError, match=ERROR_MESSAGE): + client.lsp.send_request(ERROR_TRIGGER).result() + + time.sleep(0.1) + assert len(client.messages) == 0 + + +@CustomConfiguredLSSendAll.decorate() +def test_request_error_reporting_override(client_server): + client, _ = client_server + assert len(client.messages) == 0 + + with pytest.raises(JsonRpcInternalError, match=ERROR_MESSAGE): + client.lsp.send_request(ERROR_TRIGGER).result() + + time.sleep(0.1) + assert len(client.messages) == 1 + + +@ConfiguredLS.decorate() +def test_notification_error_reporting(client_server): + client, _ = client_server + client.lsp.notify(ERROR_TRIGGER) + time.sleep(0.1) + + assert len(client.messages) == 1 + assert client.messages[0] == LanguageServer.default_error_message + + +@CustomConfiguredLSSafe.decorate() +def test_overriding_error_reporting(client_server): + client, _ = client_server + client.lsp.notify(ERROR_TRIGGER) + time.sleep(0.1) + + assert len(client.messages) == 0 + + +@CustomConfiguredLSPotentialRecusrion.decorate() +def test_overriding_error_reporting_with_potential_recursion(client_server): + client, _ = client_server + client.lsp.notify(ERROR_TRIGGER) + time.sleep(0.1) + + assert len(client.messages) == 0