diff --git a/src/jsonrpc.rs b/src/jsonrpc.rs index 831a512d..cace225d 100644 --- a/src/jsonrpc.rs +++ b/src/jsonrpc.rs @@ -115,15 +115,6 @@ pub enum Incoming { Request(ServerRequest), /// Response to a server-to-client request. Response(Response), - /// An invalid JSON-RPC request. - Invalid { - /// Request ID, if known. - #[serde(default)] - id: Option, - /// Method name, if known. - #[serde(default)] - method: Option, - }, } /// A server-to-client LSP request. @@ -276,4 +267,19 @@ mod tests { let from_value: Outgoing = serde_json::from_value(v).unwrap(); assert_eq!(from_str, from_value); } + + #[test] + fn parses_invalid_server_request() { + let unknown_method = json!({"jsonrpc":"2.0","method":"foo"}); + let _: Incoming = serde_json::from_value(unknown_method).unwrap(); + + let unknown_method_with_id = json!({"jsonrpc":"2.0","method":"foo","id":1}); + let _: Incoming = serde_json::from_value(unknown_method_with_id).unwrap(); + + let missing_method = json!({"jsonrpc":"2.0"}); + let _: Incoming = serde_json::from_value(missing_method).unwrap(); + + let missing_method_with_id = json!({"jsonrpc":"2.0","id":1}); + let _: Incoming = serde_json::from_value(missing_method_with_id).unwrap(); + } } diff --git a/src/service.rs b/src/service.rs index ecd02858..dcae9b5a 100644 --- a/src/service.rs +++ b/src/service.rs @@ -10,11 +10,11 @@ use std::task::{Context, Poll}; use futures::channel::mpsc::{self, Receiver}; use futures::stream::FusedStream; use futures::{future, FutureExt, Stream}; -use log::{error, trace}; +use log::trace; use tower_service::Service; use super::client::Client; -use super::jsonrpc::{self, ClientRequests, Incoming, Outgoing, Response, ServerRequests}; +use super::jsonrpc::{ClientRequests, Incoming, Outgoing, ServerRequests}; use super::{generated_impl, LanguageServer, ServerState, State}; /// Error that occurs when attempting to call the language server after it has already exited. @@ -130,18 +130,6 @@ impl Service for LspService { self.pending_client.insert(res); future::ok(None).boxed() } - Incoming::Invalid { id, method } => match (id, method) { - (None, Some(method)) if method.starts_with("$/") => future::ok(None).boxed(), - (id, Some(method)) => { - error!("method {:?} not found", method); - let res = Response::error(id, jsonrpc::Error::method_not_found()); - future::ok(Some(Outgoing::Response(res))).boxed() - } - (id, None) => { - let res = Response::error(id, jsonrpc::Error::invalid_request()); - future::ok(Some(Outgoing::Response(res))).boxed() - } - }, } } } diff --git a/tower-lsp-macros/src/lib.rs b/tower-lsp-macros/src/lib.rs index 15ac2d87..100ed2d1 100644 --- a/tower-lsp-macros/src/lib.rs +++ b/tower-lsp-macros/src/lib.rs @@ -149,17 +149,17 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma Ok(Some(Outgoing::Response(res))) }) - }, + } (ServerMethod::#var_name { params: Invalid(e), id }, State::Uninitialized) => { error!("invalid parameters for {:?} request", #rpc_name); let res = Response::error(Some(id), Error::invalid_params(e)); future::ok(Some(Outgoing::Response(res))).boxed() - }, + } (ServerMethod::#var_name { id, .. }, State::Initializing) => { warn!("received duplicate `initialize` request, ignoring"); let res = Response::error(Some(id), Error::invalid_request()); future::ok(Some(Outgoing::Response(res))).boxed() - }, + } }, (true, false) if rpc_name == "shutdown" => quote! { (ServerMethod::#var_name { id }, State::Initialized) => { @@ -169,7 +169,7 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma .execute(id, async move { server.#handler().await }) .map(|v| Ok(Some(Outgoing::Response(v)))) .boxed() - }, + } }, (true, true) => quote! { (ServerMethod::#var_name { params: Valid(p), id }, State::Initialized) => { @@ -177,12 +177,12 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma .execute(id, async move { server.#handler(p).await }) .map(|v| Ok(Some(Outgoing::Response(v)))) .boxed() - }, + } (ServerMethod::#var_name { params: Invalid(e), id }, State::Initialized) => { error!("invalid parameters for {:?} request", #rpc_name); let res = Response::error(Some(id), Error::invalid_params(e)); future::ok(Some(Outgoing::Response(res))).boxed() - }, + } }, (true, false) => quote! { (ServerMethod::#var_name { id }, State::Initialized) => { @@ -190,21 +190,21 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma .execute(id, async move { server.#handler().await }) .map(|v| Ok(Some(Outgoing::Response(v)))) .boxed() - }, + } }, (false, true) => quote! { (ServerMethod::#var_name { params: Valid(p) }, State::Initialized) => { Box::pin(async move { server.#handler(p).await; Ok(None) }) - }, + } (ServerMethod::#var_name { .. }, State::Initialized) => { warn!("invalid parameters for {:?} notification", #rpc_name); future::ok(None).boxed() - }, + } }, (false, false) => quote! { (ServerMethod::#var_name, State::Initialized) => { Box::pin(async move { server.#handler().await; Ok(None) }) - }, + } }, } }) @@ -236,7 +236,15 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma pub struct ServerRequest { jsonrpc: Version, #[serde(flatten)] - inner: ServerMethod, + kind: RequestKind, + } + + #[derive(Clone, Debug, PartialEq, serde::Deserialize)] + #[cfg_attr(test, derive(serde::Serialize))] + #[serde(untagged)] + enum RequestKind { + Valid(ServerMethod), + Invalid { id: Option, method: Option }, } #[derive(Clone, Debug, PartialEq, serde::Deserialize)] @@ -251,7 +259,6 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma } impl ServerMethod { - #[inline] fn id(&self) -> Option<&Id> { match *self { #id_match_arms @@ -285,34 +292,53 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma server: T, state: &Arc, pending: &ServerRequests, - incoming: ServerRequest, + request: ServerRequest, ) -> Pin, ExitedError>> + Send>> { use Params::*; - match (incoming.inner, state.get()) { + + let method = match request.kind { + RequestKind::Valid(method) => method, + RequestKind::Invalid { id: Some(id), method: Some(m) } => { + error!("method {:?} not found", m); + let res = Response::error(Some(id), Error::method_not_found()); + return future::ok(Some(Outgoing::Response(res))).boxed(); + } + RequestKind::Invalid { id: Some(id), .. } => { + let res = Response::error(Some(id), Error::invalid_request()); + return future::ok(Some(Outgoing::Response(res))).boxed(); + } + RequestKind::Invalid { id: None, method: Some(m) } if !m.starts_with("$/") => { + error!("method {:?} not found", m); + return future::ok(None).boxed(); + } + RequestKind::Invalid { id: None, .. } => return future::ok(None).boxed(), + }; + + match (method, state.get()) { #route_match_arms (ServerMethod::CancelRequest { id }, State::Initialized) => { pending.cancel(&id); future::ok(None).boxed() - }, + } (ServerMethod::Exit, _) => { info!("exit notification received, stopping"); state.set(State::Exited); pending.cancel_all(); future::ok(None).boxed() - }, + } (other, State::Uninitialized) => Box::pin(match other.id().cloned() { None => future::ok(None), Some(id) => { let res = Response::error(Some(id), not_initialized_error()); future::ok(Some(Outgoing::Response(res))) - }, + } }), (other, _) => Box::pin(match other.id().cloned() { None => future::ok(None), Some(id) => { let res = Response::error(Some(id), Error::invalid_request()); future::ok(Some(Outgoing::Response(res))) - }, + } }), } }