diff --git a/pygls/exceptions.py b/pygls/exceptions.py index 00611d94..d1be92f5 100644 --- a/pygls/exceptions.py +++ b/pygls/exceptions.py @@ -40,11 +40,12 @@ def __hash__(self): return hash((self.code, self.message)) @staticmethod - def from_dict(error): + def from_error(error): for exc_class in _EXCEPTIONS: - if exc_class.supports_code(error['code']): - return exc_class(**error) - return JsonRpcException(**error) + if exc_class.supports_code(error.code): + return exc_class(code=error.code, message=error.message, data=error.data) + + return JsonRpcException(code=error.code, message=error.message, data=error.data) @classmethod def supports_code(cls, code): diff --git a/pygls/lsp/__init__.py b/pygls/lsp/__init__.py index 5a810c7b..0b92ec3b 100644 --- a/pygls/lsp/__init__.py +++ b/pygls/lsp/__init__.py @@ -36,32 +36,6 @@ TEXT_DOCUMENT_SEMANTIC_TOKENS_RANGE: Union[SemanticTokensLegend, SemanticTokensRegistrationOptions], } -@attrs.define -class JsonRPCNotification: - """A class that represents json rpc notification message.""" - - jsonrpc: str - method: str - params: Any - -@attrs.define -class JsonRPCRequestMessage: - """A class that represents json rpc request message.""" - - jsonrpc: str - id: Union[int, str] - method: str - params: Any - - -@attrs.define -class JsonRPCResponseMessage: - """A class that represents json rpc response message.""" - - jsonrpc: str - id: Union[int, str] - result: Union[Any, None] = attrs.field(default=None) - def get_method_registration_options_type( method_name, lsp_methods_map=METHOD_TO_TYPES diff --git a/pygls/protocol.py b/pygls/protocol.py index 38d49b0b..543ac69d 100644 --- a/pygls/protocol.py +++ b/pygls/protocol.py @@ -25,22 +25,23 @@ import uuid from collections import namedtuple from concurrent.futures import Future -from functools import partial +from functools import lru_cache, partial from itertools import zip_longest -from typing import Callable, List, Optional +from typing import Any, Dict, Callable, List, Optional, Type, Union + +import attrs +from cattrs.errors import ClassValidationError +from lsprotocol import converters from pygls.capabilities import ServerCapabilitiesBuilder from pygls.constants import ATTR_FEATURE_TYPE from pygls.exceptions import (JsonRpcException, JsonRpcInternalError, JsonRpcInvalidParams, - JsonRpcMethodNotFound, JsonRpcRequestCancelled, - MethodTypeNotRegisteredError) -from pygls.feature_manager import (FeatureManager, assign_help_attrs, get_help_attrs, - is_thread_function) -from pygls.lsp import (JsonRPCNotification, JsonRPCRequestMessage, JsonRPCResponseMessage, - get_method_params_type, get_method_return_type, is_instance) -from pygls.lsp.methods import (CANCEL_REQUEST, CLIENT_REGISTER_CAPABILITY, + JsonRpcMethodNotFound, JsonRpcRequestCancelled) +from pygls.feature_manager import (FeatureManager, assign_help_attrs, is_thread_function) +from pygls.lsp import (ConfigCallbackType, ShowDocumentCallbackType) +from lsprotocol.types import (CANCEL_REQUEST, CLIENT_REGISTER_CAPABILITY, CLIENT_UNREGISTER_CAPABILITY, EXIT, INITIALIZE, INITIALIZED, - LOG_TRACE_NOTIFICATION, SET_TRACE_NOTIFICATION, SHUTDOWN, + METHOD_TO_TYPES, LOG_TRACE, SET_TRACE, SHUTDOWN, TEXT_DOCUMENT_DID_CHANGE, TEXT_DOCUMENT_DID_CLOSE, TEXT_DOCUMENT_DID_OPEN, TEXT_DOCUMENT_PUBLISH_DIAGNOSTICS, WINDOW_LOG_MESSAGE, WINDOW_SHOW_DOCUMENT, WINDOW_SHOW_MESSAGE, @@ -62,7 +63,7 @@ from pygls.workspace import Workspace logger = logging.getLogger(__name__) - +converter = converters.get_converter() def call_user_feature(base_func, method_name): """Wraps generic LSP features and calls user registered feature @@ -86,8 +87,15 @@ def decorator(self, *args, **kwargs): return decorator -def dict_to_object(**d): +def dict_to_object(d: Any): """Create nested objects (namedtuple) from dict.""" + + if d is None: + return None + + if not isinstance(d, dict): + return d + type_name = d.pop('type_name', 'Object') return json.loads( json.dumps(d), @@ -95,73 +103,48 @@ def dict_to_object(**d): ) -def default_serializer(o): - """JSON serializer for complex objects that do not extend pydantic BaseModel class.""" - if isinstance(o, enum.Enum): - return o.value - return o.__dict__ - - -def deserialize_command(params): - """Function used to deserialize command arguments to a specific class - or a namedtuple.""" - # TODO: Register/Look up custom command arguments' types - # Currently command parameters are type of 'any', but eventually we would - # want to register an argument type of our custom command and to - # deserialize it properly. - temp_obj = dict_to_object(**params, type_name='CommandParams') - - params['arguments'] = getattr(temp_obj, 'arguments', None) - return params - +@attrs.define +class JsonRPCNotification: + """A class that represents a generic json rpc notification message. + Used as a fallback for unknown types. + """ -def deserialize_params(data, get_params_type): - """Function used to deserialize params to a specific class.""" - try: - method = data['method'] - params = data['params'] + method: str + jsonrpc: str + params: Any = attrs.field(converter=dict_to_object) - if not isinstance(params, dict): - return data +@attrs.define +class JsonRPCRequestMessage: + """A class that represents a generic json rpc request message. + Used as a fallback for unknown types. + """ - try: - params_type = get_params_type(method) - if params_type is None: - params_type = dict_to_object - elif params_type.__name__ == ExecuteCommandParams.__name__: - params = deserialize_command(params) + id: Union[int, str] + method: str + jsonrpc: str + params: Any = attrs.field(converter=dict_to_object) - except MethodTypeNotRegisteredError: - params_type = dict_to_object - try: - data['params'] = params_type(**params) - except TypeError: - raise ValueError( - f'Could not instantiate "{params_type.__name__}" from params: {params}') - except KeyError: - pass +@attrs.define +class JsonRPCResponseMessage: + """A class that represents a generic json rpc response message. + Used as a fallback for unknown types. + """ - return data + id: Union[int, str] + jsonrpc: str + result: Any = attrs.field(converter=dict_to_object) +def default_serializer(o): + """JSON serializer for complex objects that do not extend pydantic BaseModel class.""" -def deserialize_message(data, get_params_type=get_method_params_type): - """Function used to deserialize data received from client.""" - if 'jsonrpc' in data: - try: - deserialize_params(data, get_params_type) - except ValueError: - raise JsonRpcInvalidParams() + if hasattr(o, '__attrs_attrs__'): + return converter.unstructure(o) - if 'id' in data: - if 'method' in data: - return JsonRPCRequestMessage(**data) - else: - return JsonRPCResponseMessage(**data) - else: - return JsonRPCNotification(**data) + if isinstance(o, enum.Enum): + return o.value - return data + return o.__dict__ class JsonRPCProtocol(asyncio.Protocol): @@ -189,8 +172,9 @@ def __init__(self, server): self._server = server self._shutdown = False - self._client_request_futures = {} - self._server_request_futures = {} + # Book keeping for in-flight requests + self._request_futures = {} + self._result_types = {} self.fm = FeatureManager(server) self.transport = None @@ -201,16 +185,6 @@ def __init__(self, server): def __call__(self): return self - def _check_ret_type_and_send_response(self, method_name, method_type, msg_id, result): - """Check if registered feature returns appropriate result type.""" - if method_type == ATTR_FEATURE_TYPE: - return_type = get_method_return_type(method_name) - if not is_instance(result, return_type): - error = JsonRpcInternalError().to_dict() - self._send_response(msg_id, error=error) - - self._send_response(msg_id, result=result) - def _execute_notification(self, handler, *params): """Executes notification message handler.""" if asyncio.iscoroutinefunction(handler): @@ -238,39 +212,34 @@ def _execute_notification_callback(self, future): def _execute_request(self, msg_id, handler, params): """Executes request message handler.""" - method_name, method_type = get_help_attrs(handler) if asyncio.iscoroutinefunction(handler): future = asyncio.ensure_future(handler(params)) - self._client_request_futures[msg_id] = future - future.add_done_callback(partial(self._execute_request_callback, - method_name, method_type, msg_id)) + self._request_futures[msg_id] = future + future.add_done_callback(partial(self._execute_request_callback, msg_id)) else: # Can't be canceled if is_thread_function(handler): self._server.thread_pool.apply_async( handler, (params, ), callback=partial( - self._check_ret_type_and_send_response, - method_name, method_type, msg_id, + self._send_response, msg_id, ), error_callback=partial(self._execute_request_err_callback, msg_id)) else: - self._check_ret_type_and_send_response( - method_name, method_type, msg_id, handler(params)) + self._send_response(msg_id, handler(params)) - def _execute_request_callback(self, method_name, method_type, msg_id, future): + def _execute_request_callback(self, msg_id, future): """Success callback used for coroutine request message.""" try: if not future.cancelled(): - self._check_ret_type_and_send_response( - method_name, method_type, msg_id, result=future.result()) + self._send_response(msg_id, result=future.result()) else: self._send_response( msg_id, error=JsonRpcRequestCancelled(f'Request with id "{msg_id}" is canceled') ) - self._client_request_futures.pop(msg_id, None) + self._request_futures.pop(msg_id, None) except Exception: error = JsonRpcInternalError.of(sys.exc_info()).to_dict() logger.exception('Exception occurred for message "%s": %s', msg_id, error) @@ -295,7 +264,7 @@ def _get_handler(self, feature_name): def _handle_cancel_notification(self, msg_id): """Handles a cancel notification from the client.""" - future = self._client_request_futures.pop(msg_id, None) + future = self._client_requests.pop(msg_id, None) if not future: logger.warning('Cancel notification for unknown message id "%s"', msg_id) @@ -340,7 +309,7 @@ def _handle_request(self, msg_id, method_name, params): def _handle_response(self, msg_id, result=None, error=None): """Handles a response from the client.""" - future = self._server_request_futures.pop(msg_id, None) + future = self._request_futures.pop(msg_id, None) if not future: logger.warning('Received response to unknown message id "%s"', msg_id) @@ -348,11 +317,47 @@ def _handle_response(self, msg_id, result=None, error=None): if error is not None: logger.debug('Received error response to message "%s": %s', msg_id, error) - future.set_exception(JsonRpcException.from_dict(error)) + future.set_exception(JsonRpcException.from_error(error)) else: logger.debug('Received result for message "%s": %s', msg_id, result) future.set_result(result) + def _deserialize_message(self, data): + """Function used to deserialize data recevied from the client.""" + + if 'jsonrpc' not in data: + return data + + try: + if 'id' in data: + if 'error' in data: + return converter.structure(data, ResponseErrorMessage) + elif 'method' in data: + request_type = ( + self.get_message_type(data['method']) or JsonRPCRequestMessage + ) + return converter.structure(data, request_type) + else: + response_type = ( + self._result_types.pop(data['id']) or JsonRPCResponseMessage + ) + return converter.structure(data, response_type) + + else: + method = data.get('method', '') + notification_type = self.get_message_type(method) or JsonRPCNotification + return converter.structure(data, notification_type) + + except ClassValidationError as exc: + logger.error("Unable to deserialize message\n%s", traceback.format_exc()) + raise JsonRpcInvalidParams() from exc + + except Exception as exc: + logger.error("Unable to deserialize message\n%s", traceback.format_exc()) + raise JsonRpcInternalError() from exc + + return data + def _procedure_handler(self, message): """Delegates message to handlers depending on message type.""" if message.jsonrpc != JsonRPCProtocol.VERSION: @@ -363,15 +368,20 @@ def _procedure_handler(self, message): logger.warning('Server shutting down. No more requests!') return - if isinstance(message, JsonRPCNotification): - logger.debug('Notification message received.') - self._handle_notification(message.method, message.params) - elif isinstance(message, JsonRPCResponseMessage): - logger.debug('Response message received.') - self._handle_response(message.id, message.result, message.error) - elif isinstance(message, JsonRPCRequestMessage): - logger.debug('Request message received.') - self._handle_request(message.id, message.method, message.params) + if hasattr(message, 'method'): + if hasattr(message, 'id'): + logger.debug('Request message received.') + self._handle_request(message.id, message.method, message.params) + else: + logger.debug('Notification message received.') + self._handle_notification(message.method, message.params) + else: + if hasattr(message, 'error'): + logger.debug('Error message received.') + self._handle_response(message.id, None, message.error) + else: + logger.debug('Response message received.') + self._handle_response(message.id, message.result) def _send_data(self, data): """Sends data to the client.""" @@ -379,7 +389,7 @@ def _send_data(self, data): return try: - body = data.json(by_alias=True, exclude_unset=True, encoder=default_serializer) + body = json.dumps(data, default=default_serializer) logger.info('Sending data: %s', body) body = body.encode(self.CHARSET) @@ -403,15 +413,15 @@ def _send_response(self, msg_id, result=None, error=None): result(any): Result returned by handler error(any): Error returned by handler """ - response = JsonRPCResponseMessage(id=msg_id, - jsonrpc=JsonRPCProtocol.VERSION, - result=result, - error=error) - if error is None: - del response.error + if error is not None: + response = ResponseErrorMessage(id=msg_id, error=error) + else: - del response.result + response_type = self._result_types.pop(msg_id, JsonRPCResponseMessage) + response = response_type( + id=msg_id, result=result, jsonrpc=JsonRPCProtocol.VERSION + ) self._send_data(response) @@ -454,19 +464,29 @@ def data_received(self, data: bytes): # Parse the body self._procedure_handler( json.loads(body.decode(self.CHARSET), - object_hook=deserialize_message)) + object_hook=self._deserialize_message)) + + def get_message_type(self, method: str) -> Optional[Type]: + """Return the type definition of the message associated with the given method.""" + return None + + def get_result_type(self, method: str) -> Optional[Type]: + """Return the type definition of the result associated with the given method.""" + return None def notify(self, method: str, params=None): """Sends a JSON RPC notification to the client.""" - logger.debug('Sending notification: "%s" %s', method, params) - request = JsonRPCNotification( - jsonrpc=JsonRPCProtocol.VERSION, + logger.debug("Sending notification: '%s' %s", method, params) + + notification_type = self.get_message_type(method) or JsonRPCNotification + notification = notification_type( method=method, - params=params + params=params, + jsonrpc=JsonRPCProtocol.VERSION ) - self._send_data(request) + self._send_data(notification) def send_request(self, method, params=None, callback=None): """Sends a JSON RPC request to the client. @@ -478,14 +498,16 @@ def send_request(self, method, params=None, callback=None): Returns: Future that will be resolved once a response has been received """ + msg_id = str(uuid.uuid4()) + request_type = self.get_message_type(method) or JsonRPCRequestMessage logger.debug('Sending request with id "%s": %s %s', msg_id, method, params) - request = JsonRPCRequestMessage( + request = request_type( id=msg_id, - jsonrpc=JsonRPCProtocol.VERSION, method=method, - params=params + params=params, + jsonrpc=JsonRPCProtocol.VERSION, ) future = Future() @@ -497,7 +519,9 @@ def wrapper(future: Future): callback(result) future.add_done_callback(wrapper) - self._server_request_futures[msg_id] = future + self._request_futures[msg_id] = future + self._result_types[msg_id] = self.get_result_type(method) + self._send_data(request) return future @@ -573,7 +597,16 @@ def _register_builtin_features(self): if callable(attr) and hasattr(attr, 'method_name'): self.fm.add_builtin_feature(attr.method_name, attr) - def apply_edit(self, edit: WorkspaceEdit, label: str = None) -> ApplyWorkspaceEditResponse: + @lru_cache() + def get_message_type(self, method: str) -> Optional[Type]: + """Return LSP type definitions, as provided by `lsprotocol`""" + return METHOD_TO_TYPES.get(method, (None,))[0] + + @lru_cache() + def get_result_type(self, method: str) -> Optional[Type]: + return METHOD_TO_TYPES.get(method, (None, None))[1] + + def apply_edit(self, edit: WorkspaceEdit, label: str = None) -> WorkspaceApplyEditResponse: """Sends apply edit request to the client.""" return self.send_request(WORKSPACE_APPLY_EDIT, ApplyWorkspaceEditParams(edit=edit, label=label)) @@ -603,7 +636,10 @@ def lsp_initialize(self, params: InitializeParams) -> InitializeResult: list(self.fm.commands.keys()), self._server.sync_kind, ).build() - logger.debug('Server capabilities: %s', self.server_capabilities.dict()) + logger.debug( + 'Server capabilities: %s', + json.dumps(self.server_capabilities, default=default_serializer) + ) root_path = params.root_path root_uri = params.root_uri or from_fs_path(root_path) @@ -624,10 +660,7 @@ def lsp_initialized(self, *args) -> None: @lsp_method(SHUTDOWN) def lsp_shutdown(self, *args) -> None: """Request from client which asks server to shutdown.""" - for future in self._client_request_futures.values(): - future.cancel() - - for future in self._server_request_futures.values(): + for future in self._request_futures.values(): future.cancel() self._shutdown = True diff --git a/pygls/server.py b/pygls/server.py index da9c4892..127f7684 100644 --- a/pygls/server.py +++ b/pygls/server.py @@ -30,7 +30,7 @@ WorkspaceEdit) from pygls.lsp.types.window import ShowDocumentCallbackType, ShowDocumentParams from pygls.progress import Progress -from pygls.protocol import LanguageServerProtocol, deserialize_message +from pygls.protocol import LanguageServerProtocol from pygls.workspace import Workspace if not IS_PYODIDE: @@ -278,7 +278,7 @@ async def connection_made(websocket, _): self.lsp.transport = WebSocketTransportAdapter(websocket, self.loop) async for message in websocket: self.lsp._procedure_handler( - json.loads(message, object_hook=deserialize_message) + json.loads(message, object_hook=self.lsp._deserialize_message) ) start_server = websockets.serve(connection_made, host, port)