diff --git a/Cargo.lock b/Cargo.lock index acc32481df0..48b1d834967 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -286,34 +286,12 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "async-trait" -version = "0.1.81" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" -[[package]] -name = "auto_impl" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c87f3f15e7794432337fc718554eaa4dc8f04c9677a950ffe366f20a162ae42" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "autocfg" version = "1.3.0" @@ -665,21 +643,25 @@ dependencies = [ "cairo-lang-test-plugin", "cairo-lang-test-utils", "cairo-lang-utils", + "crossbeam", "futures", "indent", "indoc", "itertools 0.12.1", + "jod-thread", + "libc", + "lsp-server", + "lsp-types", "pathdiff", "pretty_assertions", "rust-analyzer-salsa", + "rustc-hash", "scarb-metadata", "serde", "serde_json", "smol_str", "tempfile", "test-log", - "tokio", - "tower-lsp", "tower-service", "tracing", "tracing-chrome", @@ -1393,6 +1375,28 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -1412,6 +1416,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" @@ -1445,19 +1458,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "dashmap" -version = "5.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" -dependencies = [ - "cfg-if", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "deranged" version = "0.3.11" @@ -2263,6 +2263,12 @@ dependencies = [ "libc", ] +[[package]] +name = "jod-thread" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b23360e99b8717f20aaa4598f5a6541efbe30630039fbc7706cf954a87947ae" + [[package]] name = "js-sys" version = "0.3.69" @@ -2390,11 +2396,23 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "lsp-server" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "550446e84739dcaf6d48a4a093973850669e13e8a34d8f8d64851041be267cd9" +dependencies = [ + "crossbeam-channel", + "log", + "serde", + "serde_json", +] + [[package]] name = "lsp-types" -version = "0.94.1" +version = "0.95.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1" +checksum = "158c1911354ef73e8fe42da6b10c0484cb65c7f1007f28022e847706c1ab6984" dependencies = [ "bitflags 1.3.2", "serde", @@ -4161,40 +4179,6 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" -[[package]] -name = "tower-lsp" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4ba052b54a6627628d9b3c34c176e7eda8359b7da9acd497b9f20998d118508" -dependencies = [ - "async-trait", - "auto_impl", - "bytes", - "dashmap", - "futures", - "httparse", - "lsp-types", - "memchr", - "serde", - "serde_json", - "tokio", - "tokio-util", - "tower", - "tower-lsp-macros", - "tracing", -] - -[[package]] -name = "tower-lsp-macros" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84fd902d4e0b9a4b27f2f440108dc034e1758628a9b702f8ec61ad66355422fa" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", -] - [[package]] name = "tower-service" version = "0.3.2" diff --git a/_typos.toml b/_typos.toml index 0199b9ef51f..2dc27542324 100644 --- a/_typos.toml +++ b/_typos.toml @@ -2,6 +2,8 @@ # Don't correct the shorthand for ByteArray. ba = "ba" compilability = "compilability" +# jod_thread crate in LS +jod = "jod" [files] extend-exclude = [ diff --git a/crates/cairo-lang-language-server/Cargo.toml b/crates/cairo-lang-language-server/Cargo.toml index af3d3a2b3f6..937cc5a4a69 100644 --- a/crates/cairo-lang-language-server/Cargo.toml +++ b/crates/cairo-lang-language-server/Cargo.toml @@ -25,21 +25,27 @@ cairo-lang-starknet = { path = "../cairo-lang-starknet", version = "~2.8.2" } cairo-lang-syntax = { path = "../cairo-lang-syntax", version = "~2.8.2" } cairo-lang-test-plugin = { path = "../cairo-lang-test-plugin", version = "~2.8.2" } cairo-lang-utils = { path = "../cairo-lang-utils", version = "~2.8.2" } +crossbeam = "0.8.4" indent.workspace = true indoc.workspace = true itertools.workspace = true +jod-thread = "0.1.2" +lsp-server = "0.7.7" +lsp-types = "=0.95.0" +rustc-hash = "1.1.0" salsa.workspace = true scarb-metadata = "1.12" serde = { workspace = true, default-features = true } serde_json.workspace = true smol_str.workspace = true tempfile = "3" -tokio.workspace = true -tower-lsp = "0.20.0" tracing = "0.1" tracing-chrome = "0.7.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } +[target.'cfg(target_vendor = "apple")'.dependencies] +libc = "0.2.155" + [dev-dependencies] assert_fs = "1.1" cairo-lang-language-server = { path = ".", features = ["testing"] } diff --git a/crates/cairo-lang-language-server/src/config.rs b/crates/cairo-lang-language-server/src/config.rs index 19be7ff2aa9..baa362da62e 100644 --- a/crates/cairo-lang-language-server/src/config.rs +++ b/crates/cairo-lang-language-server/src/config.rs @@ -2,12 +2,17 @@ use std::collections::VecDeque; use std::path::PathBuf; use anyhow::Context; +use lsp_server::ErrorCode; +use lsp_types::request::WorkspaceConfiguration; +use lsp_types::{ClientCapabilities, ConfigurationItem, ConfigurationParams}; use serde_json::Value; -use tower_lsp::lsp_types::{ClientCapabilities, ConfigurationItem}; -use tower_lsp::Client; use tracing::{debug, error, warn}; use crate::lsp::capabilities::client::ClientCapabilitiesExt; +use crate::server::api::{LSPResult, LSPResultEx}; +use crate::server::client::{Notifier, Requester}; +use crate::server::schedule::Task; +use crate::state::State; // TODO(mkaput): Write a macro that will auto-generate this struct and the `reload` logic. // TODO(mkaput): Write a test that checks that fields in this struct are sorted alphabetically. @@ -39,13 +44,18 @@ pub struct Config { impl Config { /// Reloads the configuration from the language client. #[tracing::instrument(name = "reload_config", level = "trace", skip_all)] - pub async fn reload(&mut self, client: &Client, client_capabilities: &ClientCapabilities) { + pub fn reload( + &mut self, + requester: &mut Requester<'_>, + client_capabilities: &ClientCapabilities, + on_reloaded: fn(&mut State, &Notifier), + ) -> LSPResult<()> { if !client_capabilities.workspace_configuration_support() { warn!( "client does not support `workspace/configuration` requests, config will not be \ reloaded" ); - return; + return Ok(()); } let items = vec![ @@ -56,34 +66,40 @@ impl Config { }, ]; let expected_len = items.len(); - if let Ok(response) = client - .configuration(items) - .await - .context("failed to query language client for configuration items") - .inspect_err(|e| warn!("{e:?}")) - { + + let handler = move |response: Vec| { let response_len = response.len(); if response_len != expected_len { error!( "server returned unexpected number of configuration items, expected: \ {expected_len}, got: {response_len}" ); - return; + return Task::nothing(); } // This conversion is O(1), and makes popping from front also O(1). let mut response = VecDeque::from(response); - self.unmanaged_core_path = response - .pop_front() - .as_ref() - .and_then(Value::as_str) - .filter(|s| !s.is_empty()) - .map(Into::into); - self.trace_macro_diagnostics = - response.pop_front().as_ref().and_then(Value::as_bool).unwrap_or_default(); + Task::local(move |state, notifier, _, _| { + state.config.unmanaged_core_path = response + .pop_front() + .as_ref() + .and_then(Value::as_str) + .filter(|s| !s.is_empty()) + .map(Into::into); + state.config.trace_macro_diagnostics = + response.pop_front().as_ref().and_then(Value::as_bool).unwrap_or_default(); - debug!("reloaded configuration: {self:#?}"); - } + debug!("reloaded configuration: {:#?}", state.config); + + on_reloaded(state, ¬ifier); + }) + }; + + requester + .request::(ConfigurationParams { items }, handler) + .context("failed to query language client for configuration items") + .with_failure_code(ErrorCode::RequestFailed) + .inspect_err(|e| warn!("{e:?}")) } } diff --git a/crates/cairo-lang-language-server/src/ide/code_actions/add_missing_trait.rs b/crates/cairo-lang-language-server/src/ide/code_actions/add_missing_trait.rs index 486e058bdb2..c7143dd8105 100644 --- a/crates/cairo-lang-language-server/src/ide/code_actions/add_missing_trait.rs +++ b/crates/cairo-lang-language-server/src/ide/code_actions/add_missing_trait.rs @@ -11,7 +11,7 @@ use cairo_lang_semantic::resolve::Resolver; use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::{ast, SyntaxNode, TypedStablePtr, TypedSyntaxNode}; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::{CodeAction, CodeActionKind, Range, TextEdit, Url, WorkspaceEdit}; +use lsp_types::{CodeAction, CodeActionKind, Range, TextEdit, Url, WorkspaceEdit}; use tracing::debug; use crate::ide::utils::find_methods_for_type; diff --git a/crates/cairo-lang-language-server/src/ide/code_actions/expand_macro.rs b/crates/cairo-lang-language-server/src/ide/code_actions/expand_macro.rs index 219777a3f18..b0d5d40de82 100644 --- a/crates/cairo-lang-language-server/src/ide/code_actions/expand_macro.rs +++ b/crates/cairo-lang-language-server/src/ide/code_actions/expand_macro.rs @@ -1,6 +1,6 @@ use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::SyntaxNode; -use tower_lsp::lsp_types::{CodeAction, Command}; +use lsp_types::{CodeAction, Command}; use crate::lang::db::{AnalysisDatabase, LsSyntaxGroup}; diff --git a/crates/cairo-lang-language-server/src/ide/code_actions/mod.rs b/crates/cairo-lang-language-server/src/ide/code_actions/mod.rs index 9776423e3db..1b983a7502d 100644 --- a/crates/cairo-lang-language-server/src/ide/code_actions/mod.rs +++ b/crates/cairo-lang-language-server/src/ide/code_actions/mod.rs @@ -1,5 +1,5 @@ use cairo_lang_syntax::node::SyntaxNode; -use tower_lsp::lsp_types::{ +use lsp_types::{ CodeAction, CodeActionOrCommand, CodeActionParams, CodeActionResponse, Diagnostic, NumberOrString, }; diff --git a/crates/cairo-lang-language-server/src/ide/code_actions/rename_unused_variable.rs b/crates/cairo-lang-language-server/src/ide/code_actions/rename_unused_variable.rs index 6e00714b9f9..2e9c833c3cd 100644 --- a/crates/cairo-lang-language-server/src/ide/code_actions/rename_unused_variable.rs +++ b/crates/cairo-lang-language-server/src/ide/code_actions/rename_unused_variable.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use cairo_lang_syntax::node::SyntaxNode; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::{CodeAction, Diagnostic, TextEdit, Url, WorkspaceEdit}; +use lsp_types::{CodeAction, Diagnostic, TextEdit, Url, WorkspaceEdit}; use crate::lang::db::AnalysisDatabase; diff --git a/crates/cairo-lang-language-server/src/ide/completion/completions.rs b/crates/cairo-lang-language-server/src/ide/completion/completions.rs index f0ad3d21eee..d3ab62f77c9 100644 --- a/crates/cairo-lang-language-server/src/ide/completion/completions.rs +++ b/crates/cairo-lang-language-server/src/ide/completion/completions.rs @@ -19,7 +19,7 @@ use cairo_lang_semantic::{ConcreteTypeId, Pattern, TypeLongId}; use cairo_lang_syntax::node::ast::PathSegment; use cairo_lang_syntax::node::{ast, TypedStablePtr, TypedSyntaxNode}; use cairo_lang_utils::{LookupIntern, Upcast}; -use tower_lsp::lsp_types::{CompletionItem, CompletionItemKind, Position, Range, TextEdit}; +use lsp_types::{CompletionItem, CompletionItemKind, InsertTextFormat, Position, Range, TextEdit}; use tracing::debug; use crate::ide::utils::find_methods_for_type; @@ -282,7 +282,7 @@ pub fn completion_for_method( let completion = CompletionItem { label: format!("{}()", name), insert_text: Some(format!("{}($0)", name)), - insert_text_format: Some(tower_lsp::lsp_types::InsertTextFormat::SNIPPET), + insert_text_format: Some(InsertTextFormat::SNIPPET), detail: Some(detail), kind: Some(CompletionItemKind::METHOD), additional_text_edits: Some(additional_text_edits), diff --git a/crates/cairo-lang-language-server/src/ide/completion/mod.rs b/crates/cairo-lang-language-server/src/ide/completion/mod.rs index 3ccebdeb20f..f853e51226b 100644 --- a/crates/cairo-lang-language-server/src/ide/completion/mod.rs +++ b/crates/cairo-lang-language-server/src/ide/completion/mod.rs @@ -5,7 +5,7 @@ use cairo_lang_syntax::node::db::SyntaxGroup; use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::{ast, SyntaxNode, TypedSyntaxNode}; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::{CompletionParams, CompletionResponse, CompletionTriggerKind}; +use lsp_types::{CompletionParams, CompletionResponse, CompletionTriggerKind}; use tracing::debug; use self::completions::{colon_colon_completions, dot_completions, generic_completions}; diff --git a/crates/cairo-lang-language-server/src/ide/formatter.rs b/crates/cairo-lang-language-server/src/ide/formatter.rs index a918addedd4..bc4083b3486 100644 --- a/crates/cairo-lang-language-server/src/ide/formatter.rs +++ b/crates/cairo-lang-language-server/src/ide/formatter.rs @@ -2,7 +2,7 @@ use cairo_lang_filesystem::db::FilesGroup; use cairo_lang_formatter::{get_formatted_file, FormatterConfig}; use cairo_lang_parser::db::ParserGroup; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::{DocumentFormattingParams, Position, Range, TextEdit}; +use lsp_types::{DocumentFormattingParams, Position, Range, TextEdit}; use tracing::error; use crate::lang::db::AnalysisDatabase; diff --git a/crates/cairo-lang-language-server/src/ide/hover/mod.rs b/crates/cairo-lang-language-server/src/ide/hover/mod.rs index 09436dac51b..aa33437e5a8 100644 --- a/crates/cairo-lang-language-server/src/ide/hover/mod.rs +++ b/crates/cairo-lang-language-server/src/ide/hover/mod.rs @@ -1,4 +1,4 @@ -use tower_lsp::lsp_types::{Hover, HoverContents, HoverParams, MarkupContent, MarkupKind}; +use lsp_types::{Hover, HoverContents, HoverParams, MarkupContent, MarkupKind}; use crate::lang::db::{AnalysisDatabase, LsSyntaxGroup}; use crate::lang::lsp::{LsProtoGroup, ToCairo}; diff --git a/crates/cairo-lang-language-server/src/ide/hover/render/definition.rs b/crates/cairo-lang-language-server/src/ide/hover/render/definition.rs index c1c11fe1054..f3a8592d606 100644 --- a/crates/cairo-lang-language-server/src/ide/hover/render/definition.rs +++ b/crates/cairo-lang-language-server/src/ide/hover/render/definition.rs @@ -4,7 +4,7 @@ use cairo_lang_filesystem::ids::FileId; use cairo_lang_syntax::node::ast::TerminalIdentifier; use cairo_lang_syntax::node::TypedSyntaxNode; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::Hover; +use lsp_types::Hover; use crate::ide::hover::markdown_contents; use crate::ide::hover::render::markdown::{fenced_code_block, RULE}; diff --git a/crates/cairo-lang-language-server/src/ide/hover/render/legacy.rs b/crates/cairo-lang-language-server/src/ide/hover/render/legacy.rs index 27608be8dc5..842abb07533 100644 --- a/crates/cairo-lang-language-server/src/ide/hover/render/legacy.rs +++ b/crates/cairo-lang-language-server/src/ide/hover/render/legacy.rs @@ -8,7 +8,7 @@ use cairo_lang_syntax::node::ast::{Expr, Pattern, TerminalIdentifier}; use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::{SyntaxNode, TypedSyntaxNode}; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::Hover; +use lsp_types::Hover; use crate::ide::hover::markdown_contents; use crate::ide::hover::render::markdown::{fenced_code_block, RULE}; diff --git a/crates/cairo-lang-language-server/src/ide/macros/expand.rs b/crates/cairo-lang-language-server/src/ide/macros/expand.rs index e375991d8af..3123d242a2a 100644 --- a/crates/cairo-lang-language-server/src/ide/macros/expand.rs +++ b/crates/cairo-lang-language-server/src/ide/macros/expand.rs @@ -16,7 +16,7 @@ use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::{SyntaxNode, TypedSyntaxNode}; use cairo_lang_utils::Intern; use indoc::formatdoc; -use tower_lsp::lsp_types::TextDocumentPositionParams; +use lsp_types::TextDocumentPositionParams; use crate::lang::db::{AnalysisDatabase, LsSemanticGroup, LsSyntaxGroup}; use crate::lang::lsp::{LsProtoGroup, ToCairo}; diff --git a/crates/cairo-lang-language-server/src/ide/navigation/goto_definition.rs b/crates/cairo-lang-language-server/src/ide/navigation/goto_definition.rs index 2dfcb368b01..ed9ed01e130 100644 --- a/crates/cairo-lang-language-server/src/ide/navigation/goto_definition.rs +++ b/crates/cairo-lang-language-server/src/ide/navigation/goto_definition.rs @@ -2,7 +2,7 @@ use cairo_lang_filesystem::db::get_originating_location; use cairo_lang_filesystem::ids::FileId; use cairo_lang_filesystem::span::{TextPosition, TextSpan}; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::{GotoDefinitionParams, GotoDefinitionResponse, Location}; +use lsp_types::{GotoDefinitionParams, GotoDefinitionResponse, Location}; use crate::lang::db::{AnalysisDatabase, LsSemanticGroup, LsSyntaxGroup}; use crate::lang::inspect::defs::find_definition; diff --git a/crates/cairo-lang-language-server/src/ide/semantic_highlighting/mod.rs b/crates/cairo-lang-language-server/src/ide/semantic_highlighting/mod.rs index c96ce822643..cc04ccae0d6 100644 --- a/crates/cairo-lang-language-server/src/ide/semantic_highlighting/mod.rs +++ b/crates/cairo-lang-language-server/src/ide/semantic_highlighting/mod.rs @@ -5,7 +5,7 @@ use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::{ast, SyntaxNode, TypedSyntaxNode}; use cairo_lang_utils::unordered_hash_map::UnorderedHashMap; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::*; +use lsp_types::{SemanticToken, SemanticTokens, SemanticTokensParams, SemanticTokensResult}; use tracing::error; use self::encoder::{EncodedToken, TokenEncoder}; diff --git a/crates/cairo-lang-language-server/src/ide/semantic_highlighting/token_kind.rs b/crates/cairo-lang-language-server/src/ide/semantic_highlighting/token_kind.rs index e7df0990b4e..13e35aa13d6 100644 --- a/crates/cairo-lang-language-server/src/ide/semantic_highlighting/token_kind.rs +++ b/crates/cairo-lang-language-server/src/ide/semantic_highlighting/token_kind.rs @@ -8,7 +8,7 @@ use cairo_lang_syntax::node::kind::SyntaxKind; use cairo_lang_syntax::node::utils::grandparent_kind; use cairo_lang_syntax::node::{ast, SyntaxNode, Terminal, TypedSyntaxNode}; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::SemanticTokenType; +use lsp_types::SemanticTokenType; use crate::lang::db::{AnalysisDatabase, LsSemanticGroup}; diff --git a/crates/cairo-lang-language-server/src/lang/diagnostics/lsp.rs b/crates/cairo-lang-language-server/src/lang/diagnostics/lsp.rs index 0799a0e1438..b5d9e621d4e 100644 --- a/crates/cairo-lang-language-server/src/lang/diagnostics/lsp.rs +++ b/crates/cairo-lang-language-server/src/lang/diagnostics/lsp.rs @@ -2,7 +2,7 @@ use cairo_lang_diagnostics::{DiagnosticEntry, DiagnosticLocation, Diagnostics, S use cairo_lang_filesystem::db::FilesGroup; use cairo_lang_filesystem::ids::FileId; use cairo_lang_utils::Upcast; -use tower_lsp::lsp_types::{ +use lsp_types::{ Diagnostic, DiagnosticRelatedInformation, DiagnosticSeverity, Location, NumberOrString, Range, }; use tracing::error; diff --git a/crates/cairo-lang-language-server/src/lang/lsp/ls_proto_group.rs b/crates/cairo-lang-language-server/src/lang/lsp/ls_proto_group.rs index abd17e18d4f..55a60686e10 100644 --- a/crates/cairo-lang-language-server/src/lang/lsp/ls_proto_group.rs +++ b/crates/cairo-lang-language-server/src/lang/lsp/ls_proto_group.rs @@ -1,8 +1,8 @@ use cairo_lang_filesystem::db::FilesGroup; use cairo_lang_filesystem::ids::{FileId, FileLongId}; use cairo_lang_utils::Upcast; +use lsp_types::Url; use salsa::InternKey; -use tower_lsp::lsp_types::Url; use tracing::error; #[cfg(test)] @@ -23,18 +23,18 @@ pub trait LsProtoGroup: Upcast { "vfs" => uri .host_str() .or_else(|| { - error!("invalid vfs url, missing host string: {uri}"); + error!("invalid vfs url, missing host string: {uri:?}"); None })? .parse::() .inspect_err(|e| { - error!("invalid vfs url, host string is not a valid integer, {e}: {uri}") + error!("invalid vfs url, host string is not a valid integer, {e}: {uri:?}") }) .ok() .map(Into::into) .map(FileId::from_intern_id), _ => { - error!("invalid url, scheme is not supported by this language server: {uri}"); + error!("invalid url, scheme is not supported by this language server: {uri:?}"); None } } diff --git a/crates/cairo-lang-language-server/src/lang/lsp/ls_proto_group_test.rs b/crates/cairo-lang-language-server/src/lang/lsp/ls_proto_group_test.rs index 59e5114bf14..ba441e0a9c0 100644 --- a/crates/cairo-lang-language-server/src/lang/lsp/ls_proto_group_test.rs +++ b/crates/cairo-lang-language-server/src/lang/lsp/ls_proto_group_test.rs @@ -1,7 +1,7 @@ use cairo_lang_filesystem::db::FilesGroup; use cairo_lang_filesystem::ids::{FileKind, FileLongId, VirtualFile}; use cairo_lang_filesystem::test_utils::FilesDatabaseForTesting; -use tower_lsp::lsp_types::Url; +use lsp_types::Url; use super::LsProtoGroup; diff --git a/crates/cairo-lang-language-server/src/lang/lsp/to_lsp.rs b/crates/cairo-lang-language-server/src/lang/lsp/to_lsp.rs index a99cd6b21cd..5e926679db7 100644 --- a/crates/cairo-lang-language-server/src/lang/lsp/to_lsp.rs +++ b/crates/cairo-lang-language-server/src/lang/lsp/to_lsp.rs @@ -1,5 +1,5 @@ use cairo_lang_filesystem::span::{TextPosition, TextPositionSpan}; -use tower_lsp::lsp_types::{Position, Range}; +use lsp_types::{Position, Range}; /// Convert a type into its LSP equivalent. /// diff --git a/crates/cairo-lang-language-server/src/lib.rs b/crates/cairo-lang-language-server/src/lib.rs index dc91f7c0439..e902c1c3f65 100644 --- a/crates/cairo-lang-language-server/src/lib.rs +++ b/crates/cairo-lang-language-server/src/lib.rs @@ -40,12 +40,13 @@ use std::collections::{HashMap, HashSet}; use std::io; +use std::num::NonZeroUsize; use std::panic::{catch_unwind, AssertUnwindSafe, RefUnwindSafe}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::time::SystemTime; -use anyhow::Context; +use anyhow::{anyhow, Context, Result}; use cairo_lang_compiler::db::validate_corelib; use cairo_lang_compiler::project::{setup_project, update_crate_roots_from_project_config}; use cairo_lang_defs::db::DefsGroup; @@ -63,27 +64,34 @@ use cairo_lang_semantic::SemanticDiagnostic; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; use cairo_lang_utils::{Intern, LookupIntern, Upcast}; use itertools::Itertools; -use salsa::{Cancelled, ParallelDatabase}; -use state::{FileDiagnostics, StateSnapshot}; -use tokio::sync::Semaphore; -use tokio::task::spawn_blocking; -use tower_lsp::jsonrpc::{Error as LSPError, Result as LSPResult}; -use tower_lsp::lsp_types::request::Request; -use tower_lsp::lsp_types::{TextDocumentPositionParams, Url}; -use tower_lsp::{Client, ClientSocket, LanguageServer, LspService, Server}; -use tracing::{debug, error, info, trace_span, warn, Instrument}; +use lsp_server::{ErrorCode, Message}; +use lsp_types::notification::PublishDiagnostics; +use lsp_types::{ + ClientCapabilities, PublishDiagnosticsParams, Registration, RegistrationParams, Url, +}; +use salsa::Cancelled; +use server::connection::ClientSender; +use server::schedule::task::SyncTask; +use server::schedule::thread::JoinHandle; +use state::FileDiagnostics; +use tracing::{debug, error, info, trace_span, warn}; use crate::config::Config; use crate::lang::db::AnalysisDatabase; use crate::lang::diagnostics::lsp::map_cairo_diagnostics_to_lsp; use crate::lang::lsp::LsProtoGroup; -use crate::lsp::ext::{ - CorelibVersionMismatch, ProvideVirtualFileRequest, ProvideVirtualFileResponse, +use crate::lsp::capabilities::server::{ + collect_dynamic_registrations, collect_server_capabilities, }; +use crate::lsp::ext::CorelibVersionMismatch; use crate::project::scarb::update_crate_roots; use crate::project::unmanaged_core_crate::try_to_init_unmanaged_core; use crate::project::ProjectManifestPath; -use crate::server::notifier::Notifier; +use crate::server::api; +use crate::server::api::{LSPError, LSPResult}; +use crate::server::client::{Client, Notifier, Requester}; +use crate::server::connection::{Connection, ConnectionInitializer}; +use crate::server::schedule::{event_loop_thread, Scheduler, Task}; use crate::state::State; use crate::toolchain::scarb::ScarbToolchain; @@ -120,41 +128,41 @@ pub fn start() { start_with_tricks(Tricks::default()); } -/// Number of LSP requests that can be processed concurrently. -/// Higher number than default tower_lsp::DEFAULT_MAX_CONCURRENCY = 4. -/// This is increased because we don't have to limit requests this way now. -/// Cancellation will skip requests that are no longer relevant so only latest ones will be -/// processed. Effectively there will be similar number of requests processed at once, but under -/// heavy load these will be more actual ones. -const REQUESTS_PROCESSED_CONCURRENTLY: usize = 100; - /// Starts the language server with customizations. /// /// See [the top-level documentation][lib] documentation for usage examples. /// /// [lib]: crate#running-with-customizations -#[tokio::main] -pub async fn start_with_tricks(tricks: Tricks) { +pub fn start_with_tricks(tricks: Tricks) { let _log_guard = init_logging(); info!("language server starting"); env_config::report_to_logs(); - let (stdin, stdout) = (tokio::io::stdin(), tokio::io::stdout()); - - let (service, socket) = Backend::build_service(tricks); - Server::new(stdin, stdout, socket) - .concurrency_level(REQUESTS_PROCESSED_CONCURRENTLY) - .serve(service) - .await; + let exit_code = match Backend::new(tricks) { + Ok(backend) => { + if let Err(err) = backend.run().map(|handle| handle.join()) { + error!("language server encountered an unrecoverable error: {err}"); + 1 + } else { + 0 + } + } + Err(err) => { + error!("language server failed during initialization: {err}"); + 1 + } + }; info!("language server stopped"); + std::process::exit(exit_code); } /// Special function to run the language server in end-to-end tests. #[cfg(feature = "testing")] -pub fn build_service_for_e2e_tests() -> (LspService, ClientSocket) { - Backend::build_service(Tricks::default()) +pub fn build_service_for_e2e_tests() -> (Box Backend + Send>, lsp_server::Connection) +{ + Backend::new_for_testing(Default::default()) } /// Initialize logging infrastructure for the language server. @@ -164,7 +172,7 @@ fn init_logging() -> Option { use std::fs; use std::io::IsTerminal; - use tracing_chrome::{ChromeLayerBuilder, TraceStyle}; + use tracing_chrome::ChromeLayerBuilder; use tracing_subscriber::filter::{EnvFilter, LevelFilter}; use tracing_subscriber::fmt::format::FmtSpan; use tracing_subscriber::fmt::time::Uptime; @@ -204,11 +212,8 @@ fn init_logging() -> Option { "open that file with https://ui.perfetto.dev (or chrome://tracing) to analyze it" ); - let (profile_layer, profile_layer_guard) = ChromeLayerBuilder::new() - .writer(profile_file) - .trace_style(TraceStyle::Async) - .include_args(true) - .build(); + let (profile_layer, profile_layer_guard) = + ChromeLayerBuilder::new().writer(profile_file).include_args(true).build(); guard = Some(profile_layer_guard); Some(profile_layer) @@ -226,15 +231,15 @@ fn init_logging() -> Option { /// Makes sure that all open files exist in the new db, with their current changes. #[tracing::instrument(level = "trace", skip_all)] -fn ensure_exists_in_db( +fn ensure_exists_in_db<'a>( new_db: &mut AnalysisDatabase, old_db: &AnalysisDatabase, - open_files: impl Iterator, + open_files: impl Iterator, ) { let overrides = old_db.file_overrides(); let mut new_overrides: OrderedHashMap> = Default::default(); for uri in open_files { - let Some(file_id) = old_db.file_for_url(&uri) else { continue }; + let Some(file_id) = old_db.file_for_url(uri) else { continue }; let new_file_id = file_id.lookup_intern(old_db).intern(new_db); if let Some(content) = overrides.get(&file_id) { new_overrides.insert(new_file_id, content.clone()); @@ -243,158 +248,204 @@ fn ensure_exists_in_db( new_db.set_file_overrides(Arc::new(new_overrides)); } +#[cfg(not(feature = "testing"))] struct Backend { - client: Client, - tricks: Tricks, - // Lock making sure there is at most a single "diagnostic refresh" thread. - refresh_lock: tokio::sync::Mutex<()>, - // Semaphore making sure there are at most one worker and one waiter for refresh. - refresh_waiters_semaphore: tokio::sync::Semaphore, - state_mutex: tokio::sync::Mutex, - scarb_toolchain: ScarbToolchain, - last_replace: tokio::sync::Mutex, - db_replace_interval: Duration, + connection: Connection, + state: State, } -/// TODO: Remove when we move to sync world. -/// This is macro because of lifetimes problems with `self`. -macro_rules! state_mut_async { - ($state:ident, $this:ident, $($f:tt)+) => { - async { - let mut state = $this.state_mutex.lock().await; - let $state = &mut *state; - - $($f)+ - } - }; +#[cfg(feature = "testing")] +pub struct Backend { + connection: Connection, + state: State, } impl Backend { - fn build_service(tricks: Tricks) -> (LspService, ClientSocket) { - LspService::build(|client| Self::new(client, tricks)) - .custom_method(lsp::ext::ProvideVirtualFile::METHOD, Self::vfs_provide) - .custom_method(lsp::ext::ViewAnalyzedCrates::METHOD, Self::view_analyzed_crates) - .custom_method(lsp::ext::ExpandMacro::METHOD, Self::expand_macro) - .finish() + fn new(tricks: Tricks) -> Result { + let connection_initializer = ConnectionInitializer::stdio(); + + Self::new_inner(tricks, connection_initializer) + } + + #[cfg(feature = "testing")] + fn new_for_testing( + tricks: Tricks, + ) -> (Box Self + Send>, lsp_server::Connection) { + let (connection_initializer, client) = ConnectionInitializer::memory(); + + let init = Box::new(|| Self::new_inner(tricks, connection_initializer).unwrap()); + + (init, client) + } + + fn new_inner(tricks: Tricks, connection_initializer: ConnectionInitializer) -> Result { + let (id, init_params) = connection_initializer.initialize_start()?; + + let client_capabilities = init_params.capabilities; + let server_capabilities = collect_server_capabilities(&client_capabilities); + + let connection = connection_initializer.initialize_finish(id, server_capabilities)?; + let state = Self::create_state(connection.make_sender(), client_capabilities, tricks); + + Ok(Self { connection, state }) } - fn new(client: Client, tricks: Tricks) -> Self { + fn create_state( + sender: ClientSender, + client_capabilities: ClientCapabilities, + tricks: Tricks, + ) -> State { let db = AnalysisDatabase::new(&tricks); - let notifier = Notifier::new(&client); - let scarb_toolchain = ScarbToolchain::new(¬ifier); - Self { - client, - tricks, - refresh_lock: Default::default(), - refresh_waiters_semaphore: Semaphore::new(2), - state_mutex: State::new(db).into(), - scarb_toolchain, - last_replace: tokio::sync::Mutex::new(SystemTime::now()), - db_replace_interval: env_config::db_replace_interval(), - } + let notifier = Client::new(sender).notifier(); + let scarb_toolchain = ScarbToolchain::new(notifier); + + State::new(db, client_capabilities, scarb_toolchain, tricks) } - /// Catches panics and returns Err. - async fn catch_panics(&self, f: F) -> LSPResult - where - F: FnOnce() -> T + Send + 'static, - T: Send + 'static, - { - spawn_blocking(move || { - catch_unwind(AssertUnwindSafe(f)).map_err(|err| { - // Salsa is broken and sometimes when cancelled throws regular assert instead of - // [`Cancelled`]. Catch this case too. - if err.is::() - || err.downcast_ref::<&str>().is_some_and(|msg| { - msg.contains( - "assertion failed: old_memo.revisions.changed_at <= \ - revisions.changed_at", - ) - }) - { - debug!("LSP worker thread was cancelled"); - LSPError::request_cancelled() - } else { - error!("caught panic in LSP worker thread"); - LSPError::internal_error() - } - }) - }) - .await - .unwrap_or_else(|_| { - error!("failed to join LSP worker thread"); - Err(LSPError::internal_error()) + fn run(self) -> Result>> { + let Self { mut state, connection } = self; + + event_loop_thread(move || { + let scheduler = Self::initial_setup(&mut state, &connection); + + let result = Self::event_loop(&connection, scheduler); + + if let Err(err) = connection.close() { + error!("failed to close connection to the language server: {err:?}"); + } + + result }) } - /// Locks and gets a server state. - #[tracing::instrument(level = "trace", skip_all)] - async fn with_state_mut(&self, f: F) -> T - where - F: FnOnce(&mut State) -> T, - { - let mut state = self.state_mutex.lock().await; + #[cfg(feature = "testing")] + pub fn run_for_tests(self) -> Result>> { + self.run() + } - f(&mut state) + fn initial_setup<'a>(state: &'a mut State, connection: &'_ Connection) -> Scheduler<'a> { + let four = NonZeroUsize::new(4).unwrap(); + // By default, we set the number of worker threads to `num_cpus`, with a maximum of 4. + let worker_threads = std::thread::available_parallelism().unwrap_or(four).max(four); + let dynamic_registrations = collect_dynamic_registrations(&state.client_capabilities); + + let mut scheduler = Scheduler::new(state, worker_threads, connection.make_sender()); + + if let Err(error) = + Self::register_dynamic_capabilities(&mut scheduler, dynamic_registrations) + { + error!( + "failed to register dynamic capabilities, some features may not work properly: \ + {error:?}" + ) + } + + // Reloading config has to be done as a sync task to access mutable state that is borrowed + // by scheduler. + scheduler.dispatch(Task::Sync(SyncTask { + func: Box::new(|state, _notifier, requester, _responder| { + Self::reload_config(state, requester).ok(); + }), + })); + + scheduler } - /// Locks and produces server state snapshot. - #[tracing::instrument(level = "trace", skip_all)] - async fn state_snapshot(&self) -> StateSnapshot { - self.with_state_mut(|state| state.snapshot()).await + fn register_dynamic_capabilities( + scheduler: &mut Scheduler<'_>, + registrations: Vec, + ) -> Result<()> { + let response_handler = |()| { + debug!("configuration file watcher successfully registered"); + Task::nothing() + }; + + scheduler.request::( + RegistrationParams { registrations }, + response_handler, + ) } - /// Locks and produces db snapshot. - #[tracing::instrument(level = "trace", skip_all)] - async fn db_snapshot(&self) -> salsa::Snapshot { - self.with_state_mut(|state| state.db.snapshot()).await + fn event_loop(connection: &Connection, mut scheduler: Scheduler<'_>) -> Result<()> { + for msg in connection.incoming() { + if connection.handle_shutdown(&msg)? { + break; + } + let task = match msg { + Message::Request(req) => api::request(req), + Message::Notification(notification) => api::notification(notification), + Message::Response(response) => scheduler.response(response), + }; + scheduler.dispatch(task); + } + + Ok(()) } - /// Refresh diagnostics and send diffs to client. - #[tracing::instrument(level = "debug", skip_all)] - async fn refresh_diagnostics(&self) -> LSPResult<()> { - // Making sure only a single thread is refreshing diagnostics at a time, and that at most - // one thread is waiting to start refreshing. This allows changed to be grouped - // together before querying the database, as well as releasing extra threads waiting to - // start diagnostics updates. - // TODO(orizi): Consider removing when request cancellation is supported. - let Ok(waiter_permit) = self.refresh_waiters_semaphore.try_acquire() else { return Ok(()) }; - let refresh_lock = self.refresh_lock.lock().await; + /// Catches panics and returns Err. + fn catch_panics(f: impl FnOnce() -> T) -> LSPResult { + catch_unwind(AssertUnwindSafe(f)).map_err(|err| { + // Salsa is broken and sometimes when cancelled throws regular assert instead of + // [`Cancelled`]. Catch this case too. + if err.is::() + || err.downcast_ref::<&str>().is_some_and(|msg| { + msg.contains( + "assertion failed: old_memo.revisions.changed_at <= revisions.changed_at", + ) + }) + { + debug!("LSP worker thread was cancelled"); + LSPError::new( + anyhow!("LSP worker thread was cancelled"), + ErrorCode::ServerCancelled, + ) + } else { + error!("caught panic in LSP worker thread"); + LSPError::new( + anyhow!("caught panic in LSP worker thread"), + ErrorCode::InternalError, + ) + } + }) + } + /// Refresh diagnostics and send diffs to the client. + #[tracing::instrument(level = "debug", skip_all)] + fn refresh_diagnostics(state: &mut State, notifier: &Notifier) -> LSPResult<()> { + // TODO(#6318): implement a pop queue of size 1 for diags let mut files_with_set_diagnostics: HashSet = HashSet::default(); let mut processed_modules: HashSet = HashSet::default(); - let open_files_ids: HashSet = async { - let state_snapshot = self.state_snapshot().await; - let open_files = state_snapshot.open_files.iter(); - open_files.filter_map(|url| state_snapshot.db.file_for_url(url)).collect() - } - .instrument(trace_span!("get_open_files_ids")) - .await; + let open_files_ids = trace_span!("get_open_files_ids").in_scope(|| { + state + .open_files + .iter() + .filter_map(|uri| state.db.file_for_url(uri)) + .collect::>() + }); - let open_files_modules = self.get_files_modules(open_files_ids.iter()).await; + let open_files_modules = + Backend::get_files_modules(&state.db, open_files_ids.iter().copied()); // Refresh open files modules first for better UX - async { + trace_span!("refresh_open_files_modules").in_scope(|| { for (file, file_modules_ids) in open_files_modules { - self.refresh_file_diagnostics( + Backend::refresh_file_diagnostics( + state, &file, &file_modules_ids, &mut processed_modules, &mut files_with_set_diagnostics, - ) - .await; + notifier, + ); } - } - .instrument(trace_span!("refresh_open_files_modules")) - .await; + }); - let rest_of_files = async { + let rest_of_files = trace_span!("get_rest_of_files").in_scope(|| { let mut rest_of_files: HashSet = HashSet::default(); - let db = self.db_snapshot().await; - for crate_id in db.crates() { - for module_id in db.crate_modules(crate_id).iter() { - if let Ok(module_files) = db.module_files(*module_id) { + for crate_id in state.db.crates() { + for module_id in state.db.crate_modules(crate_id).iter() { + if let Ok(module_files) = state.db.module_files(*module_id) { let unprocessed_files = module_files.iter().filter(|file| !open_files_ids.contains(file)); rest_of_files.extend(unprocessed_files); @@ -402,71 +453,65 @@ impl Backend { } } rest_of_files - } - .instrument(trace_span!("get_rest_of_files")) - .await; + }); - let rest_of_files_modules = self.get_files_modules(rest_of_files.iter()).await; + let rest_of_files_modules = + Backend::get_files_modules(&state.db, rest_of_files.iter().copied()); // Refresh rest of files after, since they are not viewed currently - async { + trace_span!("refresh_other_files_modules").in_scope(|| { for (file, file_modules_ids) in rest_of_files_modules { - self.refresh_file_diagnostics( + Backend::refresh_file_diagnostics( + state, &file, &file_modules_ids, &mut processed_modules, &mut files_with_set_diagnostics, - ) - .await; + notifier, + ); } - } - .instrument(trace_span!("refresh_other_files_modules")) - .await; + }); // Clear old diagnostics - async { + trace_span!("clear_old_diagnostics").in_scope(|| { let mut removed_files = Vec::new(); - self.with_state_mut(|s| { - s.file_diagnostics.retain(|uri, _| { - let retain = files_with_set_diagnostics.contains(uri); - if !retain { - removed_files.push(uri.clone()); - } - retain - }); - }) - .await; + + state.file_diagnostics.retain(|uri, _| { + let retain = files_with_set_diagnostics.contains(uri); + if !retain { + removed_files.push(uri.clone()); + } + retain + }); for file in removed_files { - self.client - .publish_diagnostics(file, Vec::new(), None) - .instrument(trace_span!("publish_diagnostics")) - .await; + trace_span!("publish_diagnostics").in_scope(|| { + notifier.notify::(PublishDiagnosticsParams { + uri: file, + diagnostics: vec![], + version: None, + }); + }); } - } - .instrument(trace_span!("clear_old_diagnostics")) - .await; + }); - // Release locks prior to potentially swapping the database. - drop(refresh_lock); - drop(waiter_permit); - // After handling of all diagnostics attempting to swap the database to reduce memory + // After handling of all diagnostics, attempting to swap the database to reduce memory // consumption. - self.maybe_swap_database().await + // This should be an independent cronjob when diagnostics are run as a background task. + Backend::maybe_swap_database(state, notifier) } /// Refresh diagnostics for a single file. - async fn refresh_file_diagnostics( - &self, + fn refresh_file_diagnostics( + state: &mut State, file: &FileId, modules_ids: &Vec, processed_modules: &mut HashSet, files_with_set_diagnostics: &mut HashSet, + notifier: &Notifier, ) { - let state = self.state_snapshot().await; - let db = state.db; - let config = state.config; - let file_url = db.url_for_file(*file); + let db = &state.db; + let file_uri = db.url_for_file(*file); let mut semantic_file_diagnostics: Vec = vec![]; let mut lowering_file_diagnostics: Vec = vec![]; @@ -476,7 +521,7 @@ impl Backend { catch_unwind(AssertUnwindSafe(|| $db.$query($file_id))) .map($f) .inspect_err(|_| { - error!("caught panic when computing diagnostics for file {file_url}"); + error!("caught panic when computing diagnostics for file {file_uri:?}"); }) .unwrap_or_default() }) @@ -507,29 +552,20 @@ impl Backend { }; if !new_file_diagnostics.is_empty() { - files_with_set_diagnostics.insert(file_url.clone()); + files_with_set_diagnostics.insert(file_uri.clone()); } // Since we are using Arcs, this comparison should be efficient. - let skip_update = self - .with_state_mut(|state| { - if let Some(old_file_diagnostics) = state.file_diagnostics.get(&file_url) { - if old_file_diagnostics == &new_file_diagnostics { - return true; - } - } - - state.file_diagnostics.insert(file_url.clone(), new_file_diagnostics.clone()); - false - }) - .await; + if let Some(old_file_diagnostics) = state.file_diagnostics.get(&file_uri) { + if old_file_diagnostics == &new_file_diagnostics { + return; + } - if skip_update { - return; - } + state.file_diagnostics.insert(file_uri.clone(), new_file_diagnostics.clone()); + }; let mut diags = Vec::new(); - let trace_macro_diagnostics = config.trace_macro_diagnostics; + let trace_macro_diagnostics = state.config.trace_macro_diagnostics; map_cairo_diagnostics_to_lsp( (*db).upcast(), &mut diags, @@ -549,25 +585,24 @@ impl Backend { trace_macro_diagnostics, ); - // Drop database snapshot before we wait for the client responding to our notification. - drop(db); - - self.client - .publish_diagnostics(file_url, diags, None) - .instrument(trace_span!("publish_diagnostics")) - .await; + trace_span!("publish_diagnostics").in_scope(|| { + notifier.notify::(PublishDiagnosticsParams { + uri: file_uri, + diagnostics: diags, + version: None, + }); + }) } /// Gets the mapping of files to their respective modules. - async fn get_files_modules( - &self, - files_ids: impl Iterator, + fn get_files_modules( + db: &AnalysisDatabase, + files_ids: impl Iterator, ) -> HashMap> { - let state_snapshot = self.state_snapshot().await; let mut result = HashMap::default(); for file_id in files_ids { - if let Ok(file_modules) = state_snapshot.db.file_modules(*file_id) { - result.insert(*file_id, file_modules.iter().cloned().collect_vec()); + if let Ok(file_modules) = db.file_modules(file_id) { + result.insert(file_id, file_modules.iter().cloned().collect_vec()); } } result @@ -575,50 +610,41 @@ impl Backend { /// Checks if enough time passed since last db swap, and if so, swaps the database. #[tracing::instrument(level = "trace", skip_all)] - async fn maybe_swap_database(&self) -> LSPResult<()> { - let Ok(mut last_replace) = self.last_replace.try_lock() else { - // Another thread is already swapping the database. - return Ok(()); - }; - if last_replace.elapsed().unwrap() <= self.db_replace_interval { + fn maybe_swap_database(state: &mut State, notifier: &Notifier) -> LSPResult<()> { + if state.last_replace.elapsed().unwrap() <= state.db_replace_interval { // Not enough time passed since last swap. return Ok(()); } - let result = self.swap_database().await; - *last_replace = SystemTime::now(); + + let result = Backend::swap_database(state, notifier); + + state.last_replace = SystemTime::now(); + result } /// Perform database swap #[tracing::instrument(level = "debug", skip_all)] - async fn swap_database(&self) -> LSPResult<()> { - let state = self.state_snapshot().await; - let open_files = state.open_files; - let config = &state.config; - + fn swap_database(state: &mut State, notifier: &Notifier) -> LSPResult<()> { debug!("scheduled"); - let mut new_db = self - .catch_panics({ - let open_files = open_files.clone(); - let tricks = self.tricks.clone(); - - move || { - let mut new_db = AnalysisDatabase::new(&tricks); - ensure_exists_in_db(&mut new_db, &state.db, open_files.iter().cloned()); - new_db - } - }) - .await?; + let mut new_db = Backend::catch_panics(|| { + let mut new_db = AnalysisDatabase::new(&state.tricks); + ensure_exists_in_db(&mut new_db, &state.db, state.open_files.iter()); + new_db + })?; debug!("initial setup done"); - self.ensure_diagnostics_queries_up_to_date(&mut new_db, config, open_files.iter().cloned()) - .await; + Backend::ensure_diagnostics_queries_up_to_date( + &mut new_db, + &state.scarb_toolchain, + &state.config, + state.open_files.iter(), + notifier, + ); debug!("initial compilation done"); debug!("starting"); - self.with_state_mut(|state| { - ensure_exists_in_db(&mut new_db, &state.db, state.open_files.iter().cloned()); - state.db = new_db; - }) - .await; + + ensure_exists_in_db(&mut new_db, &state.db, state.open_files.iter()); + state.db = new_db; debug!("done"); Ok(()) @@ -626,11 +652,12 @@ impl Backend { /// Ensures that all diagnostics are up to date. #[tracing::instrument(level = "trace", skip_all)] - async fn ensure_diagnostics_queries_up_to_date( - &self, + fn ensure_diagnostics_queries_up_to_date<'a>( db: &mut AnalysisDatabase, + scarb_toolchain: &ScarbToolchain, config: &Config, - open_files: impl Iterator, + open_files: impl Iterator, + notifier: &Notifier, ) { let query_diags = |db: &AnalysisDatabase, file_id| { db.file_syntax_diagnostics(file_id); @@ -638,9 +665,9 @@ impl Backend { let _ = db.file_lowering_diagnostics(file_id); }; for uri in open_files { - let Some(file_id) = db.file_for_url(&uri) else { continue }; + let Some(file_id) = db.file_for_url(uri) else { continue }; if let FileLongId::OnDisk(file_path) = file_id.lookup_intern(db) { - self.detect_crate_for(db, config, file_path).await; + Backend::detect_crate_for(db, scarb_toolchain, config, &file_path, notifier); } query_diags(db, file_id); } @@ -657,80 +684,38 @@ impl Backend { } } - #[tracing::instrument(level = "trace", skip_all)] - async fn view_analyzed_crates(&self) -> LSPResult { - let db = self.db_snapshot().await; - self.catch_panics(move || lang::inspect::crates::inspect_analyzed_crates(&db)).await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn expand_macro(&self, params: TextDocumentPositionParams) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::macros::expand::expand_macro(&db, ¶ms)).await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn vfs_provide( - &self, - params: ProvideVirtualFileRequest, - ) -> LSPResult { - let db = self.db_snapshot().await; - self.catch_panics(move || { - let content = db - .file_for_url(¶ms.uri) - .and_then(|file_id| db.file_content(file_id)) - .map(|content| content.to_string()); - ProvideVirtualFileResponse { content } - }) - .await - } - /// Tries to detect the crate root the config that contains a cairo file, and add it to the /// system. #[tracing::instrument(level = "trace", skip_all)] - async fn detect_crate_for( - &self, + fn detect_crate_for( db: &mut AnalysisDatabase, + scarb_toolchain: &ScarbToolchain, config: &Config, - file_path: PathBuf, + file_path: &Path, + notifier: &Notifier, ) { - match ProjectManifestPath::discover(&file_path) { + match ProjectManifestPath::discover(file_path) { Some(ProjectManifestPath::Scarb(manifest_path)) => { - let Ok(metadata) = spawn_blocking({ - let scarb = self.scarb_toolchain.clone(); - move || { - scarb - .metadata(&manifest_path) - .with_context(|| { - format!( - "failed to refresh scarb workspace: {}", - manifest_path.display() - ) - }) - .inspect_err(|e| { - // TODO(mkaput): Send a notification to the language client. - warn!("{e:?}"); - }) - .ok() - } - }) - .await - else { - error!("scarb invoking thread panicked"); - return; - }; + let metadata = scarb_toolchain + .metadata(&manifest_path) + .with_context(|| { + format!("failed to refresh scarb workspace: {}", manifest_path.display()) + }) + .inspect_err(|e| { + // TODO(mkaput): Send a notification to the language client. + warn!("{e:?}"); + }) + .ok(); if let Some(metadata) = metadata { update_crate_roots(&metadata, db); } else { // Try to set up a corelib at least. - try_to_init_unmanaged_core(db, config, &self.scarb_toolchain); + try_to_init_unmanaged_core(db, config, scarb_toolchain); } if let Err(result) = validate_corelib(db) { - self.client - .send_notification::(result.to_string()) - .await; + notifier.notify::(result.to_string()); } } @@ -739,7 +724,7 @@ impl Backend { // DB will also be absolute. assert!(config_path.is_absolute()); - try_to_init_unmanaged_core(db, config, &self.scarb_toolchain); + try_to_init_unmanaged_core(db, config, scarb_toolchain); if let Ok(config) = ProjectConfig::from_file(&config_path) { update_crate_roots_from_project_config(db, &config); @@ -747,9 +732,9 @@ impl Backend { } None => { - try_to_init_unmanaged_core(db, config, &self.scarb_toolchain); + try_to_init_unmanaged_core(db, config, scarb_toolchain); - if let Err(err) = setup_project(&mut *db, file_path.as_path()) { + if let Err(err) = setup_project(&mut *db, file_path) { let file_path_s = file_path.to_string_lossy(); error!("error loading file {file_path_s} as a single crate: {err}"); } @@ -759,31 +744,33 @@ impl Backend { /// Reload crate detection for all open files. #[tracing::instrument(level = "trace", skip_all)] - async fn reload(&self) -> LSPResult<()> { - self.reload_config().await; - - state_mut_async! {state, self, - let db = &mut state.db; - - for uri in state.open_files.iter() { - let Some(file_id) = db.file_for_url(uri) else { continue }; - if let FileLongId::OnDisk(file_path) = db.lookup_intern_file(file_id) { - self.detect_crate_for(db, &state.config, file_path).await; - } + fn reload( + state: &mut State, + notifier: &Notifier, + requester: &mut Requester<'_>, + ) -> LSPResult<()> { + Backend::reload_config(state, requester)?; + + for uri in state.open_files.iter() { + let Some(file_id) = state.db.file_for_url(uri) else { continue }; + if let FileLongId::OnDisk(file_path) = state.db.lookup_intern_file(file_id) { + Backend::detect_crate_for( + &mut state.db, + &state.scarb_toolchain, + &state.config, + &file_path, + notifier, + ); } } - .await; - self.refresh_diagnostics().await + Backend::refresh_diagnostics(state, notifier) } /// Reload the [`Config`] and all its dependencies. - async fn reload_config(&self) { - state_mut_async! {state, self, - state.config.reload(&self.client, &state.client_capabilities).await; - } - .await; - - self.refresh_diagnostics().await.ok(); + fn reload_config(state: &mut State, requester: &mut Requester<'_>) -> LSPResult<()> { + state.config.reload(requester, &state.client_capabilities, |state, notifier| { + Backend::refresh_diagnostics(state, notifier).ok(); + }) } } diff --git a/crates/cairo-lang-language-server/src/lsp/capabilities/client.rs b/crates/cairo-lang-language-server/src/lsp/capabilities/client.rs index 1a79aa98473..396cec18ae3 100644 --- a/crates/cairo-lang-language-server/src/lsp/capabilities/client.rs +++ b/crates/cairo-lang-language-server/src/lsp/capabilities/client.rs @@ -1,4 +1,4 @@ -use tower_lsp::lsp_types::ClientCapabilities; +use lsp_types::ClientCapabilities; macro_rules! try_or_default { ($expr:expr) => { diff --git a/crates/cairo-lang-language-server/src/lsp/capabilities/server.rs b/crates/cairo-lang-language-server/src/lsp/capabilities/server.rs index 4bdc147109a..aa411f73b2f 100644 --- a/crates/cairo-lang-language-server/src/lsp/capabilities/server.rs +++ b/crates/cairo-lang-language-server/src/lsp/capabilities/server.rs @@ -13,12 +13,7 @@ use std::ops::Not; -use missing_lsp_types::{ - CodeActionRegistrationOptions, DefinitionRegistrationOptions, - DocumentFormattingRegistrationOptions, -}; -use serde::Serialize; -use tower_lsp::lsp_types::{ +use lsp_types::{ ClientCapabilities, CodeActionProviderCapability, CompletionOptions, CompletionRegistrationOptions, DefinitionOptions, DidChangeWatchedFilesRegistrationOptions, DocumentFilter, ExecuteCommandOptions, ExecuteCommandRegistrationOptions, FileSystemWatcher, @@ -29,6 +24,11 @@ use tower_lsp::lsp_types::{ TextDocumentSyncCapability, TextDocumentSyncKind, TextDocumentSyncOptions, TextDocumentSyncSaveOptions, }; +use missing_lsp_types::{ + CodeActionRegistrationOptions, DefinitionRegistrationOptions, + DocumentFormattingRegistrationOptions, +}; +use serde::Serialize; use crate::ide::semantic_highlighting::SemanticTokenKind; use crate::lsp::capabilities::client::ClientCapabilitiesExt; @@ -257,11 +257,11 @@ fn create_registration(method: &str, registration_options: impl Serialize) -> Re } mod missing_lsp_types { - use serde::{Deserialize, Serialize}; - use tower_lsp::lsp_types::{ + use lsp_types::{ CodeActionOptions, DefinitionOptions, DocumentFormattingOptions, TextDocumentRegistrationOptions, }; + use serde::{Deserialize, Serialize}; #[derive(Debug, Eq, PartialEq, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] diff --git a/crates/cairo-lang-language-server/src/lsp/controller.rs b/crates/cairo-lang-language-server/src/lsp/controller.rs index d4543c57057..8b137891791 100644 --- a/crates/cairo-lang-language-server/src/lsp/controller.rs +++ b/crates/cairo-lang-language-server/src/lsp/controller.rs @@ -1,259 +1 @@ -use std::sync::Arc; -use cairo_lang_filesystem::db::{AsFilesGroupMut, FilesGroupEx, PrivRawFileContentQuery}; -use serde_json::Value; -use tower_lsp::jsonrpc::Result as LSPResult; -use tower_lsp::lsp_types::{ - CodeActionParams, CodeActionResponse, CompletionParams, CompletionResponse, - DidChangeConfigurationParams, DidChangeTextDocumentParams, DidChangeWatchedFilesParams, - DidCloseTextDocumentParams, DidOpenTextDocumentParams, DidSaveTextDocumentParams, - DocumentFormattingParams, ExecuteCommandParams, GotoDefinitionParams, GotoDefinitionResponse, - Hover, HoverParams, InitializeParams, InitializeResult, InitializedParams, MessageType, - SemanticTokensParams, SemanticTokensResult, TextDocumentContentChangeEvent, TextEdit, Url, - WorkspaceEdit, -}; -use tower_lsp::LanguageServer; -use tracing::{error, warn}; - -use crate::lang::lsp::LsProtoGroup; -use crate::lsp::capabilities::server::{ - collect_dynamic_registrations, collect_server_capabilities, -}; -use crate::server::commands::ServerCommands; -use crate::state::Owned; -use crate::{ide, Backend}; - -/// TODO: Remove when we move to sync world. -/// This is macro because of lifetimes problems with `self`. -macro_rules! state_mut_async { - ($state:ident, $this:ident, $($f:tt)+) => { - async { - let mut state = $this.state_mutex.lock().await; - let $state = &mut *state; - - $($f)+ - } - }; -} - -#[tower_lsp::async_trait] -impl LanguageServer for Backend { - #[tracing::instrument(level = "debug", skip_all)] - async fn initialize(&self, params: InitializeParams) -> LSPResult { - let client_capabilities = Owned::new(Arc::new(params.capabilities)); - let client_capabilities_snapshot = client_capabilities.snapshot(); - self.with_state_mut(move |state| { - state.client_capabilities = client_capabilities; - }) - .await; - - Ok(InitializeResult { - server_info: None, - capabilities: collect_server_capabilities(&client_capabilities_snapshot), - }) - } - - #[tracing::instrument(level = "debug", skip_all)] - async fn initialized(&self, _: InitializedParams) { - // Initialize the configuration. - self.reload_config().await; - - // Dynamically register capabilities. - let client_capabilities = self.state_snapshot().await.client_capabilities; - - let dynamic_registrations = collect_dynamic_registrations(&client_capabilities); - if !dynamic_registrations.is_empty() { - let result = self.client.register_capability(dynamic_registrations).await; - if let Err(err) = result { - warn!("failed to register dynamic capabilities: {err:#?}"); - } - } - } - - async fn shutdown(&self) -> LSPResult<()> { - Ok(()) - } - - #[tracing::instrument(level = "debug", skip_all)] - async fn did_change_configuration(&self, _: DidChangeConfigurationParams) { - self.reload_config().await; - } - - #[tracing::instrument(level = "debug", skip_all)] - async fn did_change_watched_files(&self, params: DidChangeWatchedFilesParams) { - // Invalidate changed cairo files. - self.with_state_mut(|state| { - for change in ¶ms.changes { - if is_cairo_file_path(&change.uri) { - let Some(file) = state.db.file_for_url(&change.uri) else { continue }; - PrivRawFileContentQuery - .in_db_mut(state.db.as_files_group_mut()) - .invalidate(&file); - } - } - }) - .await; - - // Reload workspace if a config file has changed. - for change in params.changes { - let changed_file_path = change.uri.to_file_path().unwrap_or_default(); - let changed_file_name = changed_file_path.file_name().unwrap_or_default(); - // TODO(pmagiera): react to Scarb.lock. Keep in mind Scarb does save Scarb.lock on each - // metadata call, so it is easy to fall in a loop here. - if ["Scarb.toml", "cairo_project.toml"].map(Some).contains(&changed_file_name.to_str()) - { - self.reload().await.ok(); - } - } - } - - #[tracing::instrument(level = "debug", skip_all, fields(command = params.command))] - async fn execute_command(&self, params: ExecuteCommandParams) -> LSPResult> { - let command = ServerCommands::try_from(params.command); - if let Ok(cmd) = command { - match cmd { - ServerCommands::Reload => { - self.reload().await?; - } - } - } - - match self.client.apply_edit(WorkspaceEdit::default()).await { - Ok(res) if res.applied => self.client.log_message(MessageType::INFO, "applied").await, - Ok(_) => self.client.log_message(MessageType::INFO, "rejected").await, - Err(err) => self.client.log_message(MessageType::ERROR, err).await, - } - - Ok(None) - } - - #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] - async fn did_open(&self, params: DidOpenTextDocumentParams) { - let refresh = state_mut_async! {state, self, - let uri = params.text_document.uri; - - // Try to detect the crate for physical files. - // The crate for virtual files is already known. - if uri.scheme() == "file" { - let Ok(path) = uri.to_file_path() else { return false }; - self.detect_crate_for(&mut state.db, &state.config, path).await; - } - - let Some(file_id) = state.db.file_for_url(&uri) else { return false }; - state.open_files.insert(uri); - state.db.override_file_content(file_id, Some(params.text_document.text.into())); - - true - } - .await; - - if refresh { - self.refresh_diagnostics().await.ok(); - } - } - - #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] - async fn did_change(&self, params: DidChangeTextDocumentParams) { - let text = if let Ok([TextDocumentContentChangeEvent { text, .. }]) = - TryInto::<[_; 1]>::try_into(params.content_changes) - { - text - } else { - error!("unexpected format of document change"); - return; - }; - let refresh = self - .with_state_mut(|state| { - let uri = params.text_document.uri; - let Some(file) = state.db.file_for_url(&uri) else { return false }; - state.db.override_file_content(file, Some(text.into())); - - true - }) - .await; - - if refresh { - self.refresh_diagnostics().await.ok(); - } - } - - #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] - async fn did_save(&self, params: DidSaveTextDocumentParams) { - self.with_state_mut(|state| { - let Some(file) = state.db.file_for_url(¶ms.text_document.uri) else { return }; - PrivRawFileContentQuery.in_db_mut(state.db.as_files_group_mut()).invalidate(&file); - state.db.override_file_content(file, None); - }) - .await; - } - - #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] - async fn did_close(&self, params: DidCloseTextDocumentParams) { - let refresh = self - .with_state_mut(|state| { - state.open_files.remove(¶ms.text_document.uri); - let Some(file) = state.db.file_for_url(¶ms.text_document.uri) else { - return false; - }; - state.db.override_file_content(file, None); - - true - }) - .await; - - if refresh { - self.refresh_diagnostics().await.ok(); - } - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn completion(&self, params: CompletionParams) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::completion::complete(params, &db)).await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn semantic_tokens_full( - &self, - params: SemanticTokensParams, - ) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::semantic_highlighting::semantic_highlight_full(params, &db)) - .await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn formatting( - &self, - params: DocumentFormattingParams, - ) -> LSPResult>> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::formatter::format(params, &db)).await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn hover(&self, params: HoverParams) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::hover::hover(params, &db)).await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn goto_definition( - &self, - params: GotoDefinitionParams, - ) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::navigation::goto_definition::goto_definition(params, &db)) - .await - } - - #[tracing::instrument(level = "trace", skip_all)] - async fn code_action(&self, params: CodeActionParams) -> LSPResult> { - let db = self.db_snapshot().await; - self.catch_panics(move || ide::code_actions::code_actions(params, &db)).await - } -} - -fn is_cairo_file_path(file_path: &Url) -> bool { - file_path.path().ends_with(".cairo") -} diff --git a/crates/cairo-lang-language-server/src/lsp/ext.rs b/crates/cairo-lang-language-server/src/lsp/ext.rs index 84366833d4f..5cfe59484cf 100644 --- a/crates/cairo-lang-language-server/src/lsp/ext.rs +++ b/crates/cairo-lang-language-server/src/lsp/ext.rs @@ -1,9 +1,9 @@ //! CairoLS extensions to the Language Server Protocol. +use lsp_types::notification::Notification; +use lsp_types::request::Request; +use lsp_types::{TextDocumentPositionParams, Url}; use serde::{Deserialize, Serialize}; -use tower_lsp::lsp_types::notification::Notification; -use tower_lsp::lsp_types::request::Request; -use tower_lsp::lsp_types::{TextDocumentPositionParams, Url}; /// Provides content of virtual file from the database. pub struct ProvideVirtualFile; diff --git a/crates/cairo-lang-language-server/src/server/api/mod.rs b/crates/cairo-lang-language-server/src/server/api/mod.rs new file mode 100644 index 00000000000..4576fa61d91 --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/api/mod.rs @@ -0,0 +1,231 @@ +use std::fmt; + +use lsp_server::{ErrorCode, ExtractError, Notification, Request, RequestId}; +use lsp_types::notification::{ + Cancel, DidChangeConfiguration, DidChangeTextDocument, DidChangeWatchedFiles, + DidCloseTextDocument, DidOpenTextDocument, DidSaveTextDocument, + Notification as NotificationTrait, +}; +use lsp_types::request::{ + CodeActionRequest, Completion, ExecuteCommand, Formatting, GotoDefinition, HoverRequest, + Request as RequestTrait, SemanticTokensFullRequest, +}; +use tracing::{error, warn}; + +use super::client::Responder; +use super::schedule::BackgroundSchedule; +use crate::lsp::ext::{ExpandMacro, ProvideVirtualFile, ViewAnalyzedCrates}; +use crate::server::schedule::Task; +use crate::state::State; +use crate::Backend; + +pub mod traits; + +pub(crate) fn request<'a>(request: Request) -> Task<'a> { + let id = request.id.clone(); + + match request.method.as_str() { + CodeActionRequest::METHOD => background_request_task::( + request, + BackgroundSchedule::LatencySensitive, + ), + Completion::METHOD => { + background_request_task::(request, BackgroundSchedule::LatencySensitive) + } + ExecuteCommand::METHOD => local_request_task::(request), + ExpandMacro::METHOD => { + background_request_task::(request, BackgroundSchedule::Worker) + } + Formatting::METHOD => { + background_request_task::(request, BackgroundSchedule::LatencySensitive) + } + GotoDefinition::METHOD => { + background_request_task::(request, BackgroundSchedule::LatencySensitive) + } + HoverRequest::METHOD => { + background_request_task::(request, BackgroundSchedule::LatencySensitive) + } + ProvideVirtualFile::METHOD => background_request_task::( + request, + BackgroundSchedule::LatencySensitive, + ), + SemanticTokensFullRequest::METHOD => background_request_task::( + request, + BackgroundSchedule::Worker, + ), + ViewAnalyzedCrates::METHOD => { + background_request_task::(request, BackgroundSchedule::Worker) + } + + method => { + warn!("received request {method} which does not have a handler"); + return Task::nothing(); + } + } + .unwrap_or_else(|error| { + error!("encountered error when routing request with ID {id}: {error:?}"); + let result: Result<(), LSPError> = Err(error); + Task::immediate(id, result) + }) +} + +pub(crate) fn notification<'a>(notification: Notification) -> Task<'a> { + match notification.method.as_str() { + Cancel::METHOD => local_notification_task::(notification), + DidChangeTextDocument::METHOD => { + local_notification_task::(notification) + } + DidChangeConfiguration::METHOD => { + local_notification_task::(notification) + } + DidChangeWatchedFiles::METHOD => { + local_notification_task::(notification) + } + DidCloseTextDocument::METHOD => { + local_notification_task::(notification) + } + DidOpenTextDocument::METHOD => local_notification_task::(notification), + DidSaveTextDocument::METHOD => local_notification_task::(notification), + method => { + warn!("received notification {method} which does not have a handler"); + + return Task::nothing(); + } + } + .unwrap_or_else(|error| { + error!("encountered error when routing notification: {error}"); + + Task::nothing() + }) +} + +fn local_request_task<'a, R: traits::SyncRequestHandler>( + request: Request, +) -> Result, LSPError> { + let (id, params) = cast_request::(request)?; + Ok(Task::local(move |state, notifier, requester, responder| { + let result = R::run(state, notifier, requester, params); + respond::(id, result, &responder); + })) +} + +fn background_request_task<'a, R: traits::BackgroundDocumentRequestHandler>( + request: Request, + schedule: BackgroundSchedule, +) -> Result, LSPError> { + let (id, params) = cast_request::(request)?; + Ok(Task::background(schedule, move |state: &State| { + let state_snapshot = state.snapshot(); + Box::new(move |notifier, responder| { + let result = + Backend::catch_panics(|| R::run_with_snapshot(state_snapshot, notifier, params)) + .and_then(|res| res); + respond::(id, result, &responder); + }) + })) +} + +fn local_notification_task<'a, N: traits::SyncNotificationHandler>( + notification: Notification, +) -> Result, LSPError> { + let (id, params) = cast_notification::(notification)?; + Ok(Task::local(move |session, notifier, requester, _| { + if let Err(err) = N::run(session, notifier, requester, params) { + error!("an error occurred while running {id}: {err}"); + } + })) +} + +/// Tries to cast a serialized request from the server into +/// a parameter type for a specific request handler. +/// It is *highly* recommended to not override this function in your +/// implementation. +fn cast_request(request: Request) -> Result<(RequestId, R::Params), LSPError> { + request + .extract(R::METHOD) + .map_err(|error| match error { + json_error @ ExtractError::JsonError { .. } => { + anyhow::anyhow!("JSON parsing failure:\n{json_error}") + } + ExtractError::MethodMismatch(_) => { + unreachable!( + "a method mismatch should not be possible here unless you've used a different \ + handler (`R`) than the one whose method name was matched against earlier" + ) + } + }) + .with_failure_code(ErrorCode::InternalError) +} + +/// Sends back a response to the lsp_server using a [`Responder`]. +fn respond(id: RequestId, result: LSPResult, responder: &Responder) { + if let Err(err) = &result { + error!("an error occurred with result ID {id}: {err}"); + } + if let Err(err) = responder.respond(id, result) { + error!("failed to send response: {err}"); + } +} + +/// Tries to cast a serialized request from the lsp_server into +/// a parameter type for a specific request handler. +fn cast_notification( + notification: Notification, +) -> Result<(&'static str, N::Params), LSPError> { + Ok(( + N::METHOD, + notification + .extract(N::METHOD) + .map_err(|error| match error { + json_error @ ExtractError::JsonError { .. } => { + anyhow::anyhow!("JSON parsing failure:\n{json_error}") + } + ExtractError::MethodMismatch(_) => { + unreachable!( + "a method mismatch should not be possible here unless you've used a \ + different handler (`N`) than the one whose method name was matched \ + against earlier" + ) + } + }) + .with_failure_code(ErrorCode::InternalError)?, + )) +} + +pub(crate) struct LSPError { + pub(crate) code: ErrorCode, + pub(crate) error: anyhow::Error, +} + +pub type LSPResult = Result; + +/// A trait to convert result types into the lsp_server result type, [`LSPResult`]. +pub trait LSPResultEx { + fn with_failure_code(self, code: ErrorCode) -> Result; +} + +impl> LSPResultEx for Result { + fn with_failure_code(self, code: ErrorCode) -> Result { + self.map_err(|error| LSPError::new(error.into(), code)) + } +} + +impl LSPError { + pub(crate) fn new(error: anyhow::Error, code: ErrorCode) -> Self { + Self { code, error } + } +} + +// Right now, we treat the error code as invisible data that won't +// be printed. +impl fmt::Debug for LSPError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.error.fmt(f) + } +} + +impl fmt::Display for LSPError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.error.fmt(f) + } +} diff --git a/crates/cairo-lang-language-server/src/server/api/traits.rs b/crates/cairo-lang-language-server/src/server/api/traits.rs new file mode 100644 index 00000000000..1c824cebeeb --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/api/traits.rs @@ -0,0 +1,352 @@ +//! A stateful LSP implementation that calls into the LS API. + +use cairo_lang_filesystem::db::{ + AsFilesGroupMut, FilesGroup, FilesGroupEx, PrivRawFileContentQuery, +}; +use lsp_types::notification::{ + Cancel, DidChangeConfiguration, DidChangeTextDocument, DidChangeWatchedFiles, + DidCloseTextDocument, DidOpenTextDocument, DidSaveTextDocument, Notification, +}; +use lsp_types::request::{ + CodeActionRequest, Completion, ExecuteCommand, Formatting, GotoDefinition, HoverRequest, + Request, SemanticTokensFullRequest, +}; +use lsp_types::{ + CancelParams, CodeActionParams, CodeActionResponse, CompletionParams, CompletionResponse, + DidChangeConfigurationParams, DidChangeTextDocumentParams, DidChangeWatchedFilesParams, + DidCloseTextDocumentParams, DidOpenTextDocumentParams, DidSaveTextDocumentParams, + DocumentFormattingParams, ExecuteCommandParams, GotoDefinitionParams, GotoDefinitionResponse, + Hover, HoverParams, SemanticTokensParams, SemanticTokensResult, TextDocumentContentChangeEvent, + TextDocumentPositionParams, TextEdit, Url, +}; +use serde_json::Value; +use tracing::{error, warn}; + +use crate::lang::lsp::LsProtoGroup; +use crate::lsp::ext::{ + ExpandMacro, ProvideVirtualFile, ProvideVirtualFileRequest, ProvideVirtualFileResponse, + ViewAnalyzedCrates, +}; +use crate::server::api::{LSPError, LSPResult}; +use crate::server::client::{Notifier, Requester}; +use crate::server::commands::ServerCommands; +use crate::state::{State, StateSnapshot}; +use crate::{ide, lang, Backend}; + +/// A request handler that needs mutable access to the session. +/// This will block the main message receiver loop, meaning that no +/// incoming requests or notifications will be handled while `run` is +/// executing. Try to avoid doing any I/O or long-running computations. +pub trait SyncRequestHandler: Request { + fn run( + state: &mut State, + notifier: Notifier, + requester: &mut Requester<'_>, + params: ::Params, + ) -> LSPResult<::Result>; +} + +/// A request handler that can be run on a background thread. +pub trait BackgroundDocumentRequestHandler: Request { + fn run_with_snapshot( + snapshot: StateSnapshot, + notifier: Notifier, + params: ::Params, + ) -> LSPResult<::Result>; +} + +/// A notification handler that needs mutable access to the session. +/// This will block the main message receiver loop, meaning that no +/// incoming requests or notifications will be handled while `run` is +/// executing. Try to avoid doing any I/O or long-running computations. +pub trait SyncNotificationHandler: Notification { + fn run( + state: &mut State, + notifier: Notifier, + requester: &mut Requester<'_>, + params: ::Params, + ) -> LSPResult<()>; +} + +impl BackgroundDocumentRequestHandler for CodeActionRequest { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + params: CodeActionParams, + ) -> Result, LSPError> { + Ok(ide::code_actions::code_actions(params, &snapshot.db)) + } +} + +impl SyncRequestHandler for ExecuteCommand { + #[tracing::instrument(level = "debug", skip_all, fields(command = params.command))] + fn run( + state: &mut State, + notifier: Notifier, + requester: &mut Requester<'_>, + params: ExecuteCommandParams, + ) -> LSPResult> { + let command = ServerCommands::try_from(params.command); + + if let Ok(cmd) = command { + match cmd { + ServerCommands::Reload => { + Backend::reload(state, ¬ifier, requester)?; + } + } + } + + Ok(None) + } +} + +impl BackgroundDocumentRequestHandler for HoverRequest { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + params: HoverParams, + ) -> LSPResult> { + Ok(ide::hover::hover(params, &snapshot.db)) + } +} + +impl BackgroundDocumentRequestHandler for Formatting { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + params: DocumentFormattingParams, + ) -> LSPResult>> { + Ok(ide::formatter::format(params, &snapshot.db)) + } +} + +impl SyncNotificationHandler for Cancel { + #[tracing::instrument(level = "trace", skip_all)] + fn run( + _state: &mut State, + _notifier: Notifier, + _requester: &mut Requester<'_>, + _params: CancelParams, + ) -> LSPResult<()> { + Ok(()) + } +} + +impl SyncNotificationHandler for DidChangeTextDocument { + #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] + fn run( + state: &mut State, + notifier: Notifier, + _requester: &mut Requester<'_>, + params: DidChangeTextDocumentParams, + ) -> LSPResult<()> { + let text = if let Ok([TextDocumentContentChangeEvent { text, .. }]) = + TryInto::<[_; 1]>::try_into(params.content_changes) + { + text + } else { + error!("unexpected format of document change"); + return Ok(()); + }; + + if let Some(file) = state.db.file_for_url(¶ms.text_document.uri) { + state.db.override_file_content(file, Some(text.into())); + Backend::refresh_diagnostics(state, ¬ifier)?; + }; + + Ok(()) + } +} + +impl SyncNotificationHandler for DidChangeConfiguration { + #[tracing::instrument(level = "debug", skip_all)] + fn run( + state: &mut State, + _notifier: Notifier, + requester: &mut Requester<'_>, + _params: DidChangeConfigurationParams, + ) -> LSPResult<()> { + Backend::reload_config(state, requester) + } +} + +impl SyncNotificationHandler for DidChangeWatchedFiles { + #[tracing::instrument(level = "debug", skip_all)] + fn run( + state: &mut State, + notifier: Notifier, + requester: &mut Requester<'_>, + params: DidChangeWatchedFilesParams, + ) -> LSPResult<()> { + // Invalidate changed cairo files. + for change in ¶ms.changes { + if is_cairo_file_path(&change.uri) { + let Some(file) = state.db.file_for_url(&change.uri) else { continue }; + PrivRawFileContentQuery.in_db_mut(state.db.as_files_group_mut()).invalidate(&file); + } + } + + // Reload workspace if a config file has changed. + for change in params.changes { + let changed_file_path = change.uri.to_file_path().unwrap_or_default(); + let changed_file_name = changed_file_path.file_name().unwrap_or_default(); + // TODO(pmagiera): react to Scarb.lock. Keep in mind Scarb does save Scarb.lock on each + // metadata call, so it is easy to fall in a loop here. + if ["Scarb.toml", "cairo_project.toml"].map(Some).contains(&changed_file_name.to_str()) + { + Backend::reload(state, ¬ifier, requester)?; + } + } + + Ok(()) + } +} + +impl SyncNotificationHandler for DidCloseTextDocument { + #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] + fn run( + state: &mut State, + notifier: Notifier, + _requester: &mut Requester<'_>, + params: DidCloseTextDocumentParams, + ) -> LSPResult<()> { + state.open_files.remove(¶ms.text_document.uri); + if let Some(file) = state.db.file_for_url(¶ms.text_document.uri) { + state.db.override_file_content(file, None); + Backend::refresh_diagnostics(state, ¬ifier)?; + } + + Ok(()) + } +} + +impl SyncNotificationHandler for DidOpenTextDocument { + #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] + fn run( + state: &mut State, + notifier: Notifier, + _requester: &mut Requester<'_>, + params: DidOpenTextDocumentParams, + ) -> LSPResult<()> { + let uri = params.text_document.uri; + + // Try to detect the crate for physical files. + // The crate for virtual files is already known. + if uri.scheme() == "file" { + let Ok(path) = uri.to_file_path() else { return Ok(()) }; + + Backend::detect_crate_for( + &mut state.db, + &state.scarb_toolchain, + &state.config, + &path, + ¬ifier, + ); + } + + if let Some(file_id) = state.db.file_for_url(&uri) { + state.open_files.insert(uri); + state.db.override_file_content(file_id, Some(params.text_document.text.into())); + + Backend::refresh_diagnostics(state, ¬ifier)?; + } + + Ok(()) + } +} + +impl SyncNotificationHandler for DidSaveTextDocument { + #[tracing::instrument(level = "debug", skip_all, fields(uri = %params.text_document.uri))] + fn run( + state: &mut State, + _notifier: Notifier, + _requester: &mut Requester<'_>, + params: DidSaveTextDocumentParams, + ) -> LSPResult<()> { + if let Some(file) = state.db.file_for_url(¶ms.text_document.uri) { + PrivRawFileContentQuery.in_db_mut(state.db.as_files_group_mut()).invalidate(&file); + state.db.override_file_content(file, None); + } + + Ok(()) + } +} + +impl BackgroundDocumentRequestHandler for GotoDefinition { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + params: GotoDefinitionParams, + ) -> LSPResult> { + Ok(ide::navigation::goto_definition::goto_definition(params, &snapshot.db)) + } +} + +impl BackgroundDocumentRequestHandler for Completion { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + params: CompletionParams, + ) -> LSPResult> { + Ok(ide::completion::complete(params, &snapshot.db)) + } +} + +impl BackgroundDocumentRequestHandler for SemanticTokensFullRequest { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + params: SemanticTokensParams, + ) -> LSPResult> { + Ok(ide::semantic_highlighting::semantic_highlight_full(params, &snapshot.db)) + } +} + +impl BackgroundDocumentRequestHandler for ProvideVirtualFile { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + params: ProvideVirtualFileRequest, + ) -> LSPResult { + let content = snapshot + .db + .file_for_url(¶ms.uri) + .and_then(|file_id| snapshot.db.file_content(file_id)) + .map(|content| content.to_string()); + + Ok(ProvideVirtualFileResponse { content }) + } +} + +impl BackgroundDocumentRequestHandler for ViewAnalyzedCrates { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + _params: (), + ) -> LSPResult { + Ok(lang::inspect::crates::inspect_analyzed_crates(&snapshot.db)) + } +} + +impl BackgroundDocumentRequestHandler for ExpandMacro { + #[tracing::instrument(level = "trace", skip_all)] + fn run_with_snapshot( + snapshot: StateSnapshot, + _notifier: Notifier, + params: TextDocumentPositionParams, + ) -> LSPResult> { + Ok(ide::macros::expand::expand_macro(&snapshot.db, ¶ms)) + } +} + +fn is_cairo_file_path(file_path: &Url) -> bool { + file_path.path().ends_with(".cairo") +} diff --git a/crates/cairo-lang-language-server/src/server/client.rs b/crates/cairo-lang-language-server/src/server/client.rs new file mode 100644 index 00000000000..d0e67cdf57b --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/client.rs @@ -0,0 +1,153 @@ +use std::any::TypeId; + +use anyhow::Result; +use lsp_server::{Notification, RequestId, Response}; +use lsp_types::notification::Notification as NotificationTrait; +use rustc_hash::FxHashMap; +use serde_json::Value; +use tracing::error; + +use super::schedule::Task; +use crate::server::api::LSPError; +use crate::server::connection::ClientSender; + +type ResponseBuilder<'s> = Box Task<'s>>; + +pub struct Client<'s> { + notifier: Notifier, + responder: Responder, + pub(super) requester: Requester<'s>, +} + +#[derive(Clone)] +pub struct Notifier(ClientSender); + +#[derive(Clone)] +pub struct Responder(ClientSender); + +pub struct Requester<'s> { + sender: ClientSender, + next_request_id: i32, + response_handlers: FxHashMap>, +} + +impl<'s> Client<'s> { + pub fn new(sender: ClientSender) -> Self { + Self { + notifier: Notifier(sender.clone()), + responder: Responder(sender.clone()), + requester: Requester { + sender, + next_request_id: 1, + response_handlers: FxHashMap::default(), + }, + } + } + + pub fn notifier(&self) -> Notifier { + self.notifier.clone() + } + + pub fn responder(&self) -> Responder { + self.responder.clone() + } +} + +impl Notifier { + pub fn notify(&self, params: N::Params) { + let method = N::METHOD; + + let message = + lsp_server::Message::Notification(Notification::new(method.to_string(), params)); + + if let Err(err) = self.0.send(message) { + error!("failed to send `{method}` notification: {err:?}") + } + } +} + +impl Responder { + pub fn respond(&self, id: RequestId, result: Result) -> Result<()> + where + R: serde::Serialize, + { + self.0.send( + match result { + Ok(res) => Response::new_ok(id, res), + Err(LSPError { code, error }) => { + Response::new_err(id, code as i32, format!("{error}")) + } + } + .into(), + ) + } +} + +impl<'s> Requester<'s> { + /// Sends a request of kind `R` to the client, with associated parameters. + /// The task provided by `response_handler` will be dispatched as soon as the response + /// comes back from the client. + pub fn request( + &mut self, + params: R::Params, + response_handler: impl Fn(R::Result) -> Task<'s> + 'static, + ) -> Result<()> + where + R: lsp_types::request::Request, + { + let serialized_params = serde_json::to_value(params)?; + + self.response_handlers.insert( + self.next_request_id.into(), + Box::new(move |response: lsp_server::Response| { + match (response.error, response.result) { + (Some(err), _) => { + error!("got an error from the client (code {}): {}", err.code, err.message); + Task::nothing() + } + (None, Some(response)) => match serde_json::from_value(response) { + Ok(response) => response_handler(response), + Err(error) => { + error!("failed to deserialize response from server: {error}"); + Task::nothing() + } + }, + (None, None) => { + if TypeId::of::() == TypeId::of::<()>() { + // We can't call `response_handler(())` directly here, but + // since we _know_ the type expected is `()`, we can use + // `from_value(Value::Null)`. `R::Result` implements `DeserializeOwned`, + // so this branch works in the general case but we'll only + // hit it if the concrete type is `()`, so the `unwrap()` is safe here. + response_handler(serde_json::from_value(Value::Null).unwrap()); + } else { + error!( + "server response was invalid: did not contain a result or error" + ); + } + Task::nothing() + } + } + }), + ); + + self.sender.send(lsp_server::Message::Request(lsp_server::Request { + id: self.next_request_id.into(), + method: R::METHOD.into(), + params: serialized_params, + }))?; + + self.next_request_id += 1; + + Ok(()) + } + + pub fn pop_response_task(&mut self, response: Response) -> Task<'s> { + if let Some(handler) = self.response_handlers.remove(&response.id) { + handler(response) + } else { + error!("received a response with ID {}, which was not expected", response.id); + Task::nothing() + } + } +} diff --git a/crates/cairo-lang-language-server/src/server/connection.rs b/crates/cairo-lang-language-server/src/server/connection.rs new file mode 100644 index 00000000000..142ec47762e --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/connection.rs @@ -0,0 +1,144 @@ +use std::sync::{Arc, Weak}; + +use anyhow::{bail, Result}; +use lsp_server::{ + Connection as LSPConnection, IoThreads, Message, Notification, Request, RequestId, Response, +}; +use lsp_types::notification::{Exit, Notification as NotificationTrait}; +use lsp_types::request::{Request as RequestTrait, Shutdown}; +use lsp_types::{InitializeResult, ServerCapabilities}; +use tracing::{error, info}; + +type ConnectionSender = crossbeam::channel::Sender; +type ConnectionReceiver = crossbeam::channel::Receiver; + +/// A builder for `Connection` that handles LSP initialization. +pub struct ConnectionInitializer { + connection: LSPConnection, + /// None in tests, Some(_) otherwise + threads: Option, +} + +/// Handles inbound and outbound messages with the client. +pub struct Connection { + sender: Arc, + receiver: ConnectionReceiver, + /// None in tests, Some(_) otherwise + threads: Option, +} + +impl ConnectionInitializer { + /// Create a new LSP server connection over stdin/stdout. + pub fn stdio() -> Self { + let (connection, threads) = LSPConnection::stdio(); + Self { connection, threads: Some(threads) } + } + + #[cfg(feature = "testing")] + /// Create a new LSP server connection in memory. + pub fn memory() -> (Self, LSPConnection) { + let (server, client) = LSPConnection::memory(); + (Self { connection: server, threads: None }, client) + } + + /// Starts the initialization process with the client by listening for an initialization + /// request. Returns a request ID that should be passed into `initialize_finish` later, + /// along with the initialization parameters that were provided. + pub fn initialize_start(&self) -> Result<(RequestId, lsp_types::InitializeParams)> { + let (id, params) = self.connection.initialize_start()?; + Ok((id, serde_json::from_value(params)?)) + } + + /// Finishes the initialization process with the client, + /// returning an initialized `Connection`. + pub fn initialize_finish( + self, + id: RequestId, + server_capabilities: ServerCapabilities, + ) -> Result { + let initialize_result = + InitializeResult { capabilities: server_capabilities, server_info: None }; + self.connection.initialize_finish(id, serde_json::to_value(initialize_result).unwrap())?; + let Self { connection: LSPConnection { sender, receiver }, threads } = self; + Ok(Connection { sender: Arc::new(sender), receiver, threads }) + } +} + +impl Connection { + /// Make a new `ClientSender` for sending messages to the client. + pub fn make_sender(&self) -> ClientSender { + ClientSender { weak_sender: Arc::downgrade(&self.sender) } + } + + /// An iterator over incoming messages from the client. + pub fn incoming(&self) -> crossbeam::channel::Iter<'_, Message> { + self.receiver.iter() + } + + /// Check and respond to any incoming shutdown requests; returns `true` if the server should be + /// shutdown. + pub fn handle_shutdown(&self, message: &Message) -> Result { + match message { + Message::Request(Request { id, method, .. }) if method == Shutdown::METHOD => { + self.sender.send(Response::new_ok(id.clone(), ()).into())?; + info!("shutdown request received, waiting for an exit notification..."); + match self.receiver.recv_timeout(std::time::Duration::from_secs(30))? { + Message::Notification(Notification { method, .. }) + if method == Exit::METHOD => + { + info!("exit notification received, server shutting down..."); + Ok(true) + } + message => bail!( + "server received unexpected message {message:?} while waiting for exit \ + notification" + ), + } + } + Message::Notification(Notification { method, .. }) if method == Exit::METHOD => { + error!( + "server received an exit notification before a shutdown request was sent, \ + exiting..." + ); + Ok(true) + } + _ => Ok(false), + } + } + + /// Join the I/O threads that underpin this connection. + /// This is guaranteed to be nearly immediate since + /// we close the only active channels to these threads prior + /// to joining them. + pub fn close(self) -> Result<()> { + drop( + Arc::into_inner(self.sender) + .expect("the client sender shouldn't have more than one strong reference"), + ); + drop(self.receiver); + + if let Some(threads) = self.threads { + threads.join()?; + } + Ok(()) + } +} + +/// A weak reference to an underlying sender channel, used for communication with the client. +/// If the `Connection` that created this `ClientSender` is dropped, any `send` calls will throw +/// an error. +#[derive(Clone, Debug)] +pub struct ClientSender { + weak_sender: Weak, +} + +// note: additional wrapper functions for senders may be implemented as needed. +impl ClientSender { + pub(crate) fn send(&self, msg: Message) -> Result<()> { + let Some(sender) = self.weak_sender.upgrade() else { + bail!("the connection with the client has been closed"); + }; + + Ok(sender.send(msg)?) + } +} diff --git a/crates/cairo-lang-language-server/src/server/mod.rs b/crates/cairo-lang-language-server/src/server/mod.rs index a5e45b9efc8..4756cd77856 100644 --- a/crates/cairo-lang-language-server/src/server/mod.rs +++ b/crates/cairo-lang-language-server/src/server/mod.rs @@ -1,2 +1,5 @@ +pub mod api; +pub mod client; pub mod commands; -pub mod notifier; +pub mod connection; +pub mod schedule; diff --git a/crates/cairo-lang-language-server/src/server/notifier.rs b/crates/cairo-lang-language-server/src/server/notifier.rs deleted file mode 100644 index 6234dc2e784..00000000000 --- a/crates/cairo-lang-language-server/src/server/notifier.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::fmt; - -use tokio::runtime::Handle; -use tower_lsp::lsp_types::notification::Notification; -use tower_lsp::Client; - -/// A minimal interface for sending notifications to the language client synchronously. -/// -/// This object is small and cheap to clone, so it can be passed around freely. -#[derive(Clone)] -pub struct Notifier { - /// The language client handle to which notifications will be sent. - client: Client, -} - -impl Notifier { - /// Constructs a new [`Notifier`]. - pub fn new(client: &Client) -> Self { - Notifier { client: client.clone() } - } - - /// Sends a custom notification to the client. - pub fn send_notification(&self, params: N::Params) - where - N: Notification, - { - Handle::current().block_on(self.client.send_notification::(params)) - } -} - -impl fmt::Debug for Notifier { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Notifier") - } -} diff --git a/crates/cairo-lang-language-server/src/server/schedule/mod.rs b/crates/cairo-lang-language-server/src/server/schedule/mod.rs new file mode 100644 index 00000000000..79cc98b0800 --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/schedule/mod.rs @@ -0,0 +1,91 @@ +use std::num::NonZeroUsize; + +use anyhow::Result; + +pub mod task; +pub mod thread; + +pub use task::{BackgroundSchedule, Task}; + +use self::task::{BackgroundTaskBuilder, SyncTask}; +use self::thread::ThreadPriority; +use super::client::Client; +use crate::server::connection::ClientSender; +use crate::state::State; + +/// The event loop thread is actually a secondary thread that we spawn from the +/// _actual_ main thread. This secondary thread has a larger stack size +/// than some OS defaults (Windows, for example) and is also designated as +/// high-priority. +pub(crate) fn event_loop_thread( + func: impl FnOnce() -> Result<()> + Send + 'static, +) -> Result>> { + // Override OS defaults to avoid stack overflows on platforms with low stack size defaults. + const MAIN_THREAD_STACK_SIZE: usize = 2 * 1024 * 1024; + const MAIN_THREAD_NAME: &str = "cairols:main"; + Ok(thread::Builder::new(ThreadPriority::LatencySensitive) + .name(MAIN_THREAD_NAME.into()) + .stack_size(MAIN_THREAD_STACK_SIZE) + .spawn(func)?) +} + +pub(crate) struct Scheduler<'s> { + state: &'s mut State, + client: Client<'s>, + background_pool: thread::Pool, +} + +impl<'s> Scheduler<'s> { + pub fn new(state: &'s mut State, worker_threads: NonZeroUsize, sender: ClientSender) -> Self { + Self { + state, + background_pool: thread::Pool::new(worker_threads), + client: Client::new(sender), + } + } + + /// Immediately sends a request of kind `R` to the client, with associated parameters. + /// The task provided by `response_handler` will be dispatched as soon as the response + /// comes back from the client. + pub(crate) fn request( + &mut self, + params: R::Params, + response_handler: impl Fn(R::Result) -> Task<'s> + 'static, + ) -> Result<()> + where + R: lsp_types::request::Request, + { + self.client.requester.request::(params, response_handler) + } + + /// Creates a task to handle a response from the client. + pub(crate) fn response(&mut self, response: lsp_server::Response) -> Task<'s> { + self.client.requester.pop_response_task(response) + } + + /// Dispatches a `task` by either running it as a blocking function or + /// executing it on a background thread pool. + pub(crate) fn dispatch(&mut self, task: Task<'s>) { + match task { + Task::Sync(SyncTask { func }) => { + let notifier = self.client.notifier(); + let responder = self.client.responder(); + func(self.state, notifier, &mut self.client.requester, responder); + } + Task::Background(BackgroundTaskBuilder { schedule, builder: func }) => { + let static_func = func(self.state); + let notifier = self.client.notifier(); + let responder = self.client.responder(); + let task = move || static_func(notifier, responder); + match schedule { + BackgroundSchedule::Worker => { + self.background_pool.spawn(ThreadPriority::Worker, task); + } + BackgroundSchedule::LatencySensitive => { + self.background_pool.spawn(ThreadPriority::LatencySensitive, task) + } + } + } + } + } +} diff --git a/crates/cairo-lang-language-server/src/server/schedule/task.rs b/crates/cairo-lang-language-server/src/server/schedule/task.rs new file mode 100644 index 00000000000..d895cae6941 --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/schedule/task.rs @@ -0,0 +1,89 @@ +use lsp_server::RequestId; +use serde::Serialize; +use tracing::error; + +use crate::server::api; +use crate::server::client::{Notifier, Requester, Responder}; +use crate::state::State; + +type LocalFn<'s> = Box, Responder) + 's>; + +type BackgroundFn = Box; + +type BackgroundFnBuilder<'s> = Box BackgroundFn + 's>; + +/// Describes how the task should be run. +#[derive(Clone, Copy, Debug, Default)] +pub enum BackgroundSchedule { + /// The task should be run on the general high-priority background + /// thread. + LatencySensitive, + /// The task should be run on a regular-priority background thread. + #[default] + Worker, +} + +/// A [`Task`] is a future that has not yet started, and it is the job of +/// the [`super::Scheduler`] to make that happen, via [`super::Scheduler::dispatch`]. +/// A task can either run on the main thread (in other words, the same thread as the +/// scheduler) or it can run in a background thread. The main difference between +/// the two is that background threads only have a read-only snapshot of the session, +/// while local tasks have exclusive access and can modify it as they please. Keep in mind that +/// local tasks will **block** the main event loop, so only use local tasks if you **need** +/// mutable state access, or you need the absolute lowest latency possible. +pub enum Task<'s> { + Background(BackgroundTaskBuilder<'s>), + Sync(SyncTask<'s>), +} + +// The reason why this isn't just a 'static background closure +// is because we need to take a snapshot of the state before sending +// this task to the background. The inner closure can't take the state +// as an immutable reference since it's used mutably elsewhere. So instead, +// a background task is built using an outer closure that borrows the state to take a snapshot, +// that the inner closure can capture. This builder closure has a lifetime linked to the scheduler. +// When the task is dispatched, the scheduler runs the synchronous builder, which takes the state +// as a reference, to create the inner 'static closure. That closure is then moved to a background +// task pool. +pub struct BackgroundTaskBuilder<'s> { + pub(super) schedule: BackgroundSchedule, + pub(super) builder: BackgroundFnBuilder<'s>, +} + +pub struct SyncTask<'s> { + pub func: LocalFn<'s>, +} + +impl<'s> Task<'s> { + /// Creates a new background task. + pub(crate) fn background( + schedule: BackgroundSchedule, + func: impl FnOnce(&State) -> Box + 's, + ) -> Self { + Self::Background(BackgroundTaskBuilder { schedule, builder: Box::new(func) }) + } + + /// Creates a new local task. + pub(crate) fn local( + func: impl FnOnce(&mut State, Notifier, &mut Requester<'_>, Responder) + 's, + ) -> Self { + Self::Sync(SyncTask { func: Box::new(func) }) + } + + /// Creates a local task that immediately responds with the provided `request`. + pub(crate) fn immediate(id: RequestId, result: Result) -> Self + where + R: Serialize + Send + 'static, + { + Self::local(move |_, _, _, responder| { + if let Err(err) = responder.respond(id, result) { + error!("unable to send immediate response: {err:?}"); + } + }) + } + + /// Creates a local task that does nothing. + pub(crate) fn nothing() -> Self { + Self::local(move |_, _, _, _| {}) + } +} diff --git a/crates/cairo-lang-language-server/src/server/schedule/thread/mod.rs b/crates/cairo-lang-language-server/src/server/schedule/thread/mod.rs new file mode 100644 index 00000000000..93df5b4a0fe --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/schedule/thread/mod.rs @@ -0,0 +1,84 @@ +// +------------------------------------------------------------+ +// | Code adopted from: | +// | Repository: https://github.com/rust-lang/rust-analyzer.git | +// | File: `crates/stdx/src/thread.rs` | +// | Commit: 03b3cb6be9f21c082f4206b35c7fe7f291c94eaa | +// +------------------------------------------------------------+ +//! A utility module for working with threads that automatically joins threads upon drop +//! and abstracts over operating system quality of service (QoS) APIs +//! through the concept of a “thread priority”. +//! +//! The priority of a thread is frozen at thread creation time, +//! i.e. there is no API to change the priority of a thread once it has been spawned. +//! +//! As a system, rust-analyzer should have the property that +//! old manual scheduling APIs are replaced entirely by QoS. +//! To maintain this invariant, we panic when it is clear that +//! old scheduling APIs have been used. +//! +//! Moreover, we also want to ensure that every thread has a priority set explicitly +//! to force a decision about its importance to the system. +//! Thus, [`ThreadPriority`] has no default value +//! and every entry point to creating a thread requires a [`ThreadPriority`] upfront. + +// Keeps us from getting warnings about the word `QoS` +#![allow(clippy::doc_markdown)] + +use std::fmt; + +mod pool; +mod priority; + +pub(super) use pool::Pool; +pub(super) use priority::ThreadPriority; + +pub(super) struct Builder { + priority: ThreadPriority, + inner: jod_thread::Builder, +} + +impl Builder { + pub(super) fn new(priority: ThreadPriority) -> Builder { + Builder { priority, inner: jod_thread::Builder::new() } + } + + pub(super) fn name(self, name: String) -> Builder { + Builder { inner: self.inner.name(name), ..self } + } + + pub(super) fn stack_size(self, size: usize) -> Builder { + Builder { inner: self.inner.stack_size(size), ..self } + } + + pub(super) fn spawn(self, f: F) -> std::io::Result> + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + let inner_handle = self.inner.spawn(move || { + self.priority.apply_to_current_thread(); + f() + })?; + + Ok(JoinHandle { inner: Some(inner_handle) }) + } +} + +pub struct JoinHandle { + // `inner` is an `Option` so that we can + // take ownership of the contained `JoinHandle`. + inner: Option>, +} + +impl JoinHandle { + pub(crate) fn join(mut self) -> T { + self.inner.take().unwrap().join() + } +} + +impl fmt::Debug for JoinHandle { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("JoinHandle { .. }") + } +} diff --git a/crates/cairo-lang-language-server/src/server/schedule/thread/pool.rs b/crates/cairo-lang-language-server/src/server/schedule/thread/pool.rs new file mode 100644 index 00000000000..b9921cfb235 --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/schedule/thread/pool.rs @@ -0,0 +1,90 @@ +// +------------------------------------------------------------+ +// | Code adopted from: | +// | Repository: https://github.com/rust-lang/rust-analyzer.git | +// | File: `crates/stdx/src/thread/pool.rs` | +// | Commit: 03b3cb6be9f21c082f4206b35c7fe7f291c94eaa | +// +------------------------------------------------------------+ +//! [`Pool`] implements a basic custom thread pool +//! inspired by the [`threadpool` crate](http://docs.rs/threadpool). +//! When you spawn a task, you specify a thread priority +//! so the pool can schedule it to run on a thread with that priority. +//! rust-analyzer uses this to prioritize work based on latency requirements. +//! +//! The thread pool is implemented entirely using +//! the threading utilities in [`crate::server::schedule::thread`]. + +use std::num::NonZeroUsize; + +use crossbeam::channel::{Receiver, Sender}; + +use super::{Builder, JoinHandle, ThreadPriority}; + +pub(crate) struct Pool { + // `_handles` is never read: the field is present + // only for its `Drop` impl. + + // The worker threads exit once the channel closes; + // make sure to keep `job_sender` above `handles` + // so that the channel is actually closed + // before we join the worker threads! + job_sender: Sender, + _handles: Vec, +} + +struct Job { + requested_priority: ThreadPriority, + f: Box, +} + +impl Pool { + pub(crate) fn new(threads: NonZeroUsize) -> Pool { + // Override OS defaults to avoid stack overflows on platforms with low stack size defaults. + const STACK_SIZE: usize = 2 * 1024 * 1024; + const INITIAL_PRIORITY: ThreadPriority = ThreadPriority::Worker; + + let threads = usize::from(threads); + + // Channel buffer capacity is between 2 and 4, depending on the pool size. + let (job_sender, job_receiver) = crossbeam::channel::bounded(std::cmp::min(threads * 2, 4)); + + let mut handles = Vec::with_capacity(threads); + for i in 0..threads { + let handle = Builder::new(INITIAL_PRIORITY) + .stack_size(STACK_SIZE) + .name(format!("cairo-ls:worker:{i}")) + .spawn({ + let job_receiver: Receiver = job_receiver.clone(); + move || { + let mut current_priority = INITIAL_PRIORITY; + for job in job_receiver { + if job.requested_priority != current_priority { + job.requested_priority.apply_to_current_thread(); + current_priority = job.requested_priority; + } + (job.f)(); + } + } + }) + .expect("failed to spawn thread"); + + handles.push(handle); + } + + Pool { _handles: handles, job_sender } + } + + pub(crate) fn spawn(&self, priority: ThreadPriority, f: F) + where + F: FnOnce() + Send + 'static, + { + let f = Box::new(move || { + if cfg!(debug_assertions) { + priority.assert_is_used_on_current_thread(); + } + f(); + }); + + let job = Job { requested_priority: priority, f }; + self.job_sender.send(job).unwrap(); + } +} diff --git a/crates/cairo-lang-language-server/src/server/schedule/thread/priority.rs b/crates/cairo-lang-language-server/src/server/schedule/thread/priority.rs new file mode 100644 index 00000000000..b7f8ddc3e7e --- /dev/null +++ b/crates/cairo-lang-language-server/src/server/schedule/thread/priority.rs @@ -0,0 +1,285 @@ +// +------------------------------------------------------------+ +// | Code adopted from: | +// | Repository: https://github.com/rust-lang/rust-analyzer.git | +// | File: `crates/stdx/src/thread/intent.rs` | +// | Commit: 03b3cb6be9f21c082f4206b35c7fe7f291c94eaa | +// +------------------------------------------------------------+ +//! An opaque facade around platform-specific QoS APIs. + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +// Please maintain order from least to most priority for the derived `Ord` impl. +pub(crate) enum ThreadPriority { + /// Any thread which does work that isn't in a critical path. + Worker, + + /// Any thread which does work caused by the user typing, or + /// work that the editor may wait on. + LatencySensitive, +} + +impl ThreadPriority { + // These APIs must remain private; + // we only want consumers to set thread priority + // during thread creation. + + pub(crate) fn apply_to_current_thread(self) { + let class = thread_priority_to_qos_class(self); + set_current_thread_qos_class(class); + } + + pub(crate) fn assert_is_used_on_current_thread(self) { + if IS_QOS_AVAILABLE { + let class = thread_priority_to_qos_class(self); + assert_eq!(get_current_thread_qos_class(), Some(class)); + } + } +} + +use imp::QoSClass; + +const IS_QOS_AVAILABLE: bool = imp::IS_QOS_AVAILABLE; + +fn set_current_thread_qos_class(class: QoSClass) { + imp::set_current_thread_qos_class(class); +} + +fn get_current_thread_qos_class() -> Option { + imp::get_current_thread_qos_class() +} + +fn thread_priority_to_qos_class(priority: ThreadPriority) -> QoSClass { + imp::thread_priority_to_qos_class(priority) +} + +// All Apple platforms use XNU as their kernel +// and thus have the concept of QoS. +#[cfg(target_vendor = "apple")] +mod imp { + use super::ThreadPriority; + + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + // Please maintain order from least to most priority for the derived `Ord` impl. + pub(super) enum QoSClass { + // Documentation adapted from https://github.com/apple-oss-distributions/libpthread/blob/67e155c94093be9a204b69637d198eceff2c7c46/include/sys/qos.h#L55 + /// TLDR: invisible maintenance tasks + /// + /// Contract: + /// + /// * **You do not care about how long it takes for work to finish.** + /// * **You do not care about work being deferred temporarily.** (e.g. if the device's + /// battery is in a critical state) + /// + /// Examples: + /// + /// * in a video editor: creating periodic backups of project files + /// * in a browser: cleaning up cached sites which have not been accessed in a long time + /// * in a collaborative word processor: creating a searchable index of all documents + /// + /// Use this QoS class for background tasks + /// which the user did not initiate themselves + /// and which are invisible to the user. + /// It is expected that this work will take significant time to complete: + /// minutes or even hours. + /// + /// This QoS class provides the most energy and thermally-efficient execution possible. + /// All other work is prioritized over background tasks. + Background, + + /// TLDR: tasks that don't block using your app + /// + /// Contract: + /// + /// * **Your app remains useful even as the task is executing.** + /// + /// Examples: + /// + /// * in a video editor: exporting a video to disk - the user can still work on the timeline + /// * in a browser: automatically extracting a downloaded zip file - the user can still + /// switch tabs + /// * in a collaborative word processor: downloading images embedded in a document - the + /// user can still make edits + /// + /// Use this QoS class for tasks which + /// may or may not be initiated by the user, + /// but whose result is visible. + /// It is expected that this work will take a few seconds to a few minutes. + /// Typically your app will include a progress bar + /// for tasks using this class. + /// + /// This QoS class provides a balance between + /// performance, responsiveness and efficiency. + Utility, + + /// TLDR: tasks that block using your app + /// + /// Contract: + /// + /// * **You need this work to complete before the user can keep interacting with your app.** + /// * **Your work will not take more than a few seconds to complete.** + /// + /// Examples: + /// + /// * in a video editor: opening a saved project + /// * in a browser: loading a list of the user's bookmarks and top sites when a new tab is + /// created + /// * in a collaborative word processor: running a search on the document's content + /// + /// Use this QoS class for tasks which were initiated by the user + /// and block the usage of your app while they are in progress. + /// It is expected that this work will take a few seconds or less to complete; + /// not long enough to cause the user to switch to something else. + /// Your app will likely indicate progress on these tasks + /// through the display of placeholder content or modals. + /// + /// This QoS class is not energy-efficient. + /// Rather, it provides responsiveness + /// by prioritizing work above other tasks on the system + /// except for critical user-interactive work. + UserInitiated, + + /// TLDR: render loops and nothing else + /// + /// Contract: + /// + /// * **You absolutely need this work to complete immediately or your app will appear to + /// freeze.** + /// * **Your work will always complete virtually instantaneously.** + /// + /// Examples: + /// + /// * the main thread in a GUI application + /// * the update & render loop in a game + /// * a secondary thread which progresses an animation + /// + /// Use this QoS class for any work which, if delayed, + /// will make your user interface unresponsive. + /// It is expected that this work will be virtually instantaneous. + /// + /// This QoS class is not energy-efficient. + /// Specifying this class is a request to run with + /// nearly all available system CPU and I/O bandwidth even under contention. + UserInteractive, + } + + pub(super) const IS_QOS_AVAILABLE: bool = true; + + pub(super) fn set_current_thread_qos_class(class: QoSClass) { + let c = match class { + QoSClass::UserInteractive => libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE, + QoSClass::UserInitiated => libc::qos_class_t::QOS_CLASS_USER_INITIATED, + QoSClass::Utility => libc::qos_class_t::QOS_CLASS_UTILITY, + QoSClass::Background => libc::qos_class_t::QOS_CLASS_BACKGROUND, + }; + + #[allow(unsafe_code)] + let code = unsafe { libc::pthread_set_qos_class_self_np(c, 0) }; + + if code == 0 { + return; + } + + #[allow(unsafe_code)] + let errno = unsafe { *libc::__error() }; + + match errno { + libc::EPERM => { + // This thread has been excluded from the QoS system + // due to a previous call to a function such as `pthread_setschedparam` + // which is incompatible with QoS. + // + // Panic instead of returning an error + // to maintain the invariant that we only use QoS APIs. + panic!("tried to set QoS of thread which has opted out of QoS (os error {errno})") + } + + libc::EINVAL => { + // This is returned if we pass something other than a qos_class_t + // to `pthread_set_qos_class_self_np`. + // + // This is impossible, so again panic. + unreachable!( + "invalid qos_class_t value was passed to pthread_set_qos_class_self_np" + ) + } + + _ => { + // `pthread_set_qos_class_self_np`’s documentation + // does not mention any other errors. + unreachable!("`pthread_set_qos_class_self_np` returned unexpected error {errno}") + } + } + } + + pub(super) fn get_current_thread_qos_class() -> Option { + #[allow(unsafe_code)] + let current_thread = unsafe { libc::pthread_self() }; + let mut qos_class_raw = libc::qos_class_t::QOS_CLASS_UNSPECIFIED; + #[allow(unsafe_code)] + let code = unsafe { + libc::pthread_get_qos_class_np(current_thread, &mut qos_class_raw, std::ptr::null_mut()) + }; + + if code != 0 { + // `pthread_get_qos_class_np`’s documentation states that + // an error value is placed into errno if the return code is not zero. + // However, it never states what errors are possible. + // Inspecting the source[0] shows that, as of this writing, it always returns zero. + // + // Whatever errors the function could report in future are likely to be + // ones which we cannot handle anyway + // + // 0: https://github.com/apple-oss-distributions/libpthread/blob/67e155c94093be9a204b69637d198eceff2c7c46/src/qos.c#L171-L177 + #[allow(unsafe_code)] + let errno = unsafe { *libc::__error() }; + unreachable!("`pthread_get_qos_class_np` failed unexpectedly (os error {errno})"); + } + + match qos_class_raw { + libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE => Some(QoSClass::UserInteractive), + libc::qos_class_t::QOS_CLASS_USER_INITIATED => Some(QoSClass::UserInitiated), + libc::qos_class_t::QOS_CLASS_DEFAULT => None, // QoS has never been set + libc::qos_class_t::QOS_CLASS_UTILITY => Some(QoSClass::Utility), + libc::qos_class_t::QOS_CLASS_BACKGROUND => Some(QoSClass::Background), + + libc::qos_class_t::QOS_CLASS_UNSPECIFIED => { + // Using manual scheduling APIs causes threads to “opt out” of QoS. + // At this point they become incompatible with QoS, + // and as such have the “unspecified” QoS class. + // + // Panic instead of returning an error + // to maintain the invariant that we only use QoS APIs. + panic!("tried to get QoS of thread which has opted out of QoS") + } + } + } + + pub(super) fn thread_priority_to_qos_class(priority: ThreadPriority) -> QoSClass { + match priority { + ThreadPriority::Worker => QoSClass::Utility, + ThreadPriority::LatencySensitive => QoSClass::UserInitiated, + } + } +} + +// FIXME: Windows has QoS APIs, we should use them! +#[cfg(not(target_vendor = "apple"))] +mod imp { + use super::ThreadPriority; + + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + pub(super) enum QoSClass { + Default, + } + + pub(super) const IS_QOS_AVAILABLE: bool = false; + + pub(super) fn set_current_thread_qos_class(_: QoSClass) {} + + pub(super) fn get_current_thread_qos_class() -> Option { + None + } + + pub(super) fn thread_priority_to_qos_class(_: ThreadPriority) -> QoSClass { + QoSClass::Default + } +} diff --git a/crates/cairo-lang-language-server/src/state.rs b/crates/cairo-lang-language-server/src/state.rs index 5e2217cbde1..ff3a77d1346 100644 --- a/crates/cairo-lang-language-server/src/state.rs +++ b/crates/cairo-lang-language-server/src/state.rs @@ -1,16 +1,19 @@ use std::collections::{HashMap, HashSet}; use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use std::time::{Duration, SystemTime}; use cairo_lang_diagnostics::Diagnostics; use cairo_lang_lowering::diagnostic::LoweringDiagnostic; use cairo_lang_parser::ParserDiagnostic; use cairo_lang_semantic::SemanticDiagnostic; +use lsp_types::{ClientCapabilities, Url}; use salsa::ParallelDatabase; -use tower_lsp::lsp_types::{ClientCapabilities, Url}; use crate::config::Config; use crate::lang::db::AnalysisDatabase; +use crate::toolchain::scarb::ScarbToolchain; +use crate::{env_config, Tricks}; /// State of Language server. pub struct State { @@ -19,6 +22,10 @@ pub struct State { pub file_diagnostics: Owned>, pub config: Owned, pub client_capabilities: Owned, + pub scarb_toolchain: ScarbToolchain, + pub last_replace: SystemTime, + pub db_replace_interval: Duration, + pub tricks: Owned, } #[derive(Clone, Default, PartialEq, Eq)] @@ -36,22 +43,32 @@ impl FileDiagnostics { impl std::panic::UnwindSafe for FileDiagnostics {} impl State { - pub fn new(db: AnalysisDatabase) -> Self { + pub fn new( + db: AnalysisDatabase, + client_capabilities: ClientCapabilities, + scarb_toolchain: ScarbToolchain, + tricks: Tricks, + ) -> Self { Self { db, open_files: Default::default(), file_diagnostics: Default::default(), config: Default::default(), - client_capabilities: Default::default(), + client_capabilities: Owned::new(client_capabilities.into()), + tricks: Owned::new(tricks.into()), + scarb_toolchain, + last_replace: SystemTime::now(), + db_replace_interval: env_config::db_replace_interval(), } } pub fn snapshot(&self) -> StateSnapshot { StateSnapshot { db: self.db.snapshot(), - open_files: self.open_files.snapshot(), - config: self.config.snapshot(), - client_capabilities: self.client_capabilities.snapshot(), + _open_files: self.open_files.snapshot(), + _config: self.config.snapshot(), + _client_capabilities: self.client_capabilities.snapshot(), + _tricks: self.tricks.snapshot(), } } } @@ -59,9 +76,10 @@ impl State { /// Readonly snapshot of Language server state. pub struct StateSnapshot { pub db: salsa::Snapshot, - pub open_files: Snapshot>, - pub config: Snapshot, - pub client_capabilities: Snapshot, + pub _open_files: Snapshot>, + pub _config: Snapshot, + pub _client_capabilities: Snapshot, + pub _tricks: Snapshot, } impl std::panic::UnwindSafe for StateSnapshot {} diff --git a/crates/cairo-lang-language-server/src/toolchain/scarb.rs b/crates/cairo-lang-language-server/src/toolchain/scarb.rs index 445d3974b2d..902bc00155e 100644 --- a/crates/cairo-lang-language-server/src/toolchain/scarb.rs +++ b/crates/cairo-lang-language-server/src/toolchain/scarb.rs @@ -2,12 +2,12 @@ use std::path::{Path, PathBuf}; use std::sync::{Arc, OnceLock}; use anyhow::{bail, Context, Result}; +use lsp_types::notification::Notification; use scarb_metadata::{Metadata, MetadataCommand}; -use tower_lsp::lsp_types::notification::Notification; use tracing::{error, warn}; use crate::env_config; -use crate::server::notifier::Notifier; +use crate::server::client::Notifier; pub const SCARB_TOML: &str = "Scarb.toml"; @@ -34,12 +34,8 @@ pub struct ScarbToolchain { impl ScarbToolchain { /// Constructs a new [`ScarbToolchain`]. - pub fn new(notifier: &Notifier) -> Self { - ScarbToolchain { - scarb_path_cell: Default::default(), - notifier: notifier.clone(), - is_silent: false, - } + pub fn new(notifier: Notifier) -> Self { + ScarbToolchain { scarb_path_cell: Default::default(), notifier, is_silent: false } } /// Finds the path to the `scarb` executable to use. @@ -58,7 +54,7 @@ impl ScarbToolchain { warn!("attempt to use scarb without SCARB env being set"); } else { error!("attempt to use scarb without SCARB env being set"); - self.notifier.send_notification::(()); + self.notifier.notify::(()); } } path @@ -117,7 +113,7 @@ impl ScarbToolchain { }; if !self.is_silent { - self.notifier.send_notification::(()); + self.notifier.notify::(()); } let result = MetadataCommand::new() @@ -128,7 +124,7 @@ impl ScarbToolchain { .context("failed to execute: scarb metadata"); if !self.is_silent { - self.notifier.send_notification::(()); + self.notifier.notify::(()); } result diff --git a/crates/cairo-lang-language-server/tests/e2e/analysis.rs b/crates/cairo-lang-language-server/tests/e2e/analysis.rs index 57c3e2a34ff..a1e6e081955 100644 --- a/crates/cairo-lang-language-server/tests/e2e/analysis.rs +++ b/crates/cairo-lang-language-server/tests/e2e/analysis.rs @@ -1,7 +1,7 @@ use cairo_lang_language_server::lsp; use indoc::indoc; +use lsp_types::{lsp_request, ExecuteCommandParams}; use pretty_assertions::assert_eq; -use tower_lsp::lsp_types::{lsp_request, ApplyWorkspaceEditResponse, ExecuteCommandParams}; use crate::support::normalize::normalize; use crate::support::sandbox; @@ -119,11 +119,6 @@ fn test_reload() { let expected = ls.send_request::(()); - ls.expect_request::(|_| ApplyWorkspaceEditResponse { - applied: true, - failure_reason: None, - failed_change: None, - }); ls.send_request::(ExecuteCommandParams { command: "cairo.reload".into(), ..Default::default() diff --git a/crates/cairo-lang-language-server/tests/e2e/code_actions.rs b/crates/cairo-lang-language-server/tests/e2e/code_actions.rs index 95076517e59..435982360f7 100644 --- a/crates/cairo-lang-language-server/tests/e2e/code_actions.rs +++ b/crates/cairo-lang-language-server/tests/e2e/code_actions.rs @@ -1,6 +1,6 @@ use cairo_lang_test_utils::parse_test_file::TestRunnerResult; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; -use tower_lsp::lsp_types::{ +use lsp_types::{ lsp_request, ClientCapabilities, CodeActionContext, CodeActionOrCommand, CodeActionParams, HoverClientCapabilities, MarkupKind, Range, TextDocumentClientCapabilities, }; diff --git a/crates/cairo-lang-language-server/tests/e2e/completions.rs b/crates/cairo-lang-language-server/tests/e2e/completions.rs index e8901e84703..344ddd10058 100644 --- a/crates/cairo-lang-language-server/tests/e2e/completions.rs +++ b/crates/cairo-lang-language-server/tests/e2e/completions.rs @@ -1,6 +1,6 @@ use cairo_lang_test_utils::parse_test_file::TestRunnerResult; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; -use tower_lsp::lsp_types::{lsp_request, CompletionParams, TextDocumentPositionParams}; +use lsp_types::{lsp_request, CompletionParams, TextDocumentPositionParams}; use crate::support::cursor::peek_caret; use crate::support::{cursors, sandbox}; @@ -56,8 +56,8 @@ fn test_completions_text_edits( ls.send_request::(completion_params); if let Some(completions) = caret_completions { let completion_items = match completions { - tower_lsp::lsp_types::CompletionResponse::Array(items) => items, - tower_lsp::lsp_types::CompletionResponse::List(list) => list.items, + lsp_types::CompletionResponse::Array(items) => items, + lsp_types::CompletionResponse::List(list) => list.items, }; for completion in completion_items { if let Some(text_edit) = completion.additional_text_edits { diff --git a/crates/cairo-lang-language-server/tests/e2e/goto.rs b/crates/cairo-lang-language-server/tests/e2e/goto.rs index b750f3fb32a..36114b98fe2 100644 --- a/crates/cairo-lang-language-server/tests/e2e/goto.rs +++ b/crates/cairo-lang-language-server/tests/e2e/goto.rs @@ -1,6 +1,6 @@ use cairo_lang_test_utils::parse_test_file::TestRunnerResult; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; -use tower_lsp::lsp_types::{ +use lsp_types::{ lsp_request, ClientCapabilities, GotoCapability, GotoDefinitionParams, GotoDefinitionResponse, TextDocumentClientCapabilities, TextDocumentIdentifier, TextDocumentPositionParams, }; diff --git a/crates/cairo-lang-language-server/tests/e2e/hover.rs b/crates/cairo-lang-language-server/tests/e2e/hover.rs index d758be12005..13e28e04813 100644 --- a/crates/cairo-lang-language-server/tests/e2e/hover.rs +++ b/crates/cairo-lang-language-server/tests/e2e/hover.rs @@ -1,6 +1,6 @@ use cairo_lang_test_utils::parse_test_file::TestRunnerResult; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; -use tower_lsp::lsp_types::{ +use lsp_types::{ lsp_request, ClientCapabilities, Hover, HoverClientCapabilities, HoverContents, HoverParams, MarkupContent, MarkupKind, TextDocumentClientCapabilities, TextDocumentPositionParams, }; diff --git a/crates/cairo-lang-language-server/tests/e2e/macro_expand.rs b/crates/cairo-lang-language-server/tests/e2e/macro_expand.rs index 67e15d02595..78dcd47295d 100644 --- a/crates/cairo-lang-language-server/tests/e2e/macro_expand.rs +++ b/crates/cairo-lang-language-server/tests/e2e/macro_expand.rs @@ -1,7 +1,7 @@ use cairo_lang_language_server::lsp::ext::ExpandMacro; use cairo_lang_test_utils::parse_test_file::TestRunnerResult; use cairo_lang_utils::ordered_hash_map::OrderedHashMap; -use tower_lsp::lsp_types::{TextDocumentIdentifier, TextDocumentPositionParams}; +use lsp_types::{TextDocumentIdentifier, TextDocumentPositionParams}; use crate::support::cursor::peek_caret; use crate::support::{cursors, sandbox}; diff --git a/crates/cairo-lang-language-server/tests/e2e/semantic_tokens.rs b/crates/cairo-lang-language-server/tests/e2e/semantic_tokens.rs index 5fb9a04f70e..b078f3e182b 100644 --- a/crates/cairo-lang-language-server/tests/e2e/semantic_tokens.rs +++ b/crates/cairo-lang-language-server/tests/e2e/semantic_tokens.rs @@ -1,5 +1,4 @@ use lsp_types::lsp_request; -use tower_lsp::lsp_types; use crate::support::sandbox; diff --git a/crates/cairo-lang-language-server/tests/e2e/support/client_capabilities.rs b/crates/cairo-lang-language-server/tests/e2e/support/client_capabilities.rs index 6ad0546c8c5..8d8dc33757d 100644 --- a/crates/cairo-lang-language-server/tests/e2e/support/client_capabilities.rs +++ b/crates/cairo-lang-language-server/tests/e2e/support/client_capabilities.rs @@ -1,5 +1,3 @@ -use tower_lsp::lsp_types; - /// Produces minimal client capabilities provided by the mock language client. /// /// Tests will most often need to extend these with test-specific additions using the diff --git a/crates/cairo-lang-language-server/tests/e2e/support/cursor.rs b/crates/cairo-lang-language-server/tests/e2e/support/cursor.rs index afce1c08295..d817dddbcc5 100644 --- a/crates/cairo-lang-language-server/tests/e2e/support/cursor.rs +++ b/crates/cairo-lang-language-server/tests/e2e/support/cursor.rs @@ -2,7 +2,7 @@ use std::cmp::min; use std::str::Chars; use itertools::{Itertools, MultiPeek}; -use tower_lsp::lsp_types::{Position, Range}; +use lsp_types::{Position, Range}; #[path = "cursor_test.rs"] mod test; diff --git a/crates/cairo-lang-language-server/tests/e2e/support/cursor_test.rs b/crates/cairo-lang-language-server/tests/e2e/support/cursor_test.rs index 9eec4bed8be..3ddb7b7cc27 100644 --- a/crates/cairo-lang-language-server/tests/e2e/support/cursor_test.rs +++ b/crates/cairo-lang-language-server/tests/e2e/support/cursor_test.rs @@ -1,4 +1,4 @@ -use tower_lsp::lsp_types::Position; +use lsp_types::Position; use super::cursors; diff --git a/crates/cairo-lang-language-server/tests/e2e/support/fixture.rs b/crates/cairo-lang-language-server/tests/e2e/support/fixture.rs index 022dc6666af..b861c59461c 100644 --- a/crates/cairo-lang-language-server/tests/e2e/support/fixture.rs +++ b/crates/cairo-lang-language-server/tests/e2e/support/fixture.rs @@ -3,7 +3,7 @@ use std::path::{Path, PathBuf}; use assert_fs::prelude::*; use assert_fs::TempDir; -use tower_lsp::lsp_types::Url; +use lsp_types::Url; /// A temporary directory that is a context for testing the language server. pub struct Fixture { diff --git a/crates/cairo-lang-language-server/tests/e2e/support/jsonrpc.rs b/crates/cairo-lang-language-server/tests/e2e/support/jsonrpc.rs index 6af7dd1f88c..3c683057656 100644 --- a/crates/cairo-lang-language-server/tests/e2e/support/jsonrpc.rs +++ b/crates/cairo-lang-language-server/tests/e2e/support/jsonrpc.rs @@ -1,49 +1,17 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tower_lsp::jsonrpc::{Id, Request, Response, Result}; - -/// An incoming or outgoing JSON-RPC message. -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(untagged)] -pub enum Message { - /// A response message. - Response(Response), - /// A request or notification message. - Request(Request), -} - -impl Message { - /// Creates a JSON-RPC request message from untyped parts. - pub fn request(method: &'static str, id: Id, params: Value) -> Message { - let mut b = Request::build(method).id(id); - if !params.is_null() { - b = b.params(params); - } - Message::Request(b.finish()) - } - - /// Creates a JSON-RPC notification message from untyped parts. - pub fn notification(method: &'static str, params: Value) -> Message { - Message::Request(Request::build(method).params(params).finish()) - } - - /// Creates a JSON-RPC response message from untyped parts. - pub fn response(id: Id, result: Result) -> Message { - Message::Response(Response::from_parts(id, result)) - } -} +use lsp_server::RequestId; /// A utility object for generating unique IDs for JSON-RPC requests. #[derive(Default)] pub struct RequestIdGenerator { - next_id: i64, + next_id: i32, } impl RequestIdGenerator { /// Generates a new unique request ID. - pub fn next(&mut self) -> Id { + pub fn next(&mut self) -> RequestId { let id = self.next_id; self.next_id = self.next_id.wrapping_add(1); - Id::Number(id) + + id.into() } } diff --git a/crates/cairo-lang-language-server/tests/e2e/support/mock_client.rs b/crates/cairo-lang-language-server/tests/e2e/support/mock_client.rs index fb0cbf9571f..130075f4a24 100644 --- a/crates/cairo-lang-language-server/tests/e2e/support/mock_client.rs +++ b/crates/cairo-lang-language-server/tests/e2e/support/mock_client.rs @@ -1,23 +1,17 @@ use std::collections::VecDeque; use std::ffi::OsStr; use std::path::Path; -use std::sync::Arc; use std::time::Duration; -use std::{fmt, future, mem, process}; +use std::{fmt, mem, process}; use cairo_lang_language_server::build_service_for_e2e_tests; -use futures::channel::mpsc; -use futures::{join, stream, FutureExt, SinkExt, StreamExt, TryFutureExt}; -use lsp_types::request::Request as LspRequest; +use lsp_server::{Message, Notification, Request, Response}; +use lsp_types::request::{RegisterCapability, Request as LspRequest}; use lsp_types::{lsp_notification, lsp_request}; use serde_json::Value; -use tokio::time::timeout; -use tower_lsp::{jsonrpc, lsp_types, ClientSocket, LanguageServer, LspService}; -use tower_service::Service; use crate::support::fixture::Fixture; -use crate::support::jsonrpc::{Message, RequestIdGenerator}; -use crate::support::runtime::{AbortOnDrop, GuardedRuntime}; +use crate::support::jsonrpc::RequestIdGenerator; /// A mock language client implementation that facilitates end-to-end testing language servers. /// @@ -25,21 +19,15 @@ use crate::support::runtime::{AbortOnDrop, GuardedRuntime}; /// /// The language server is terminated abruptly upon dropping of this struct. /// The `shutdown` request and `exit` notifications are not sent at all. -/// Instead, the Tokio Runtime executing the server is being shut down and any running +/// Instead, the thread executing the server is being shut down and any running /// blocking tasks are given a small period of time to complete. pub struct MockClient { fixture: Fixture, - // NOTE: The runtime is wrapped in `Arc`, which is then cloned at usage places, so that we do - // not have to reference `*self` while trying to block on it. - // This enables `async` blocks (that are blocked on) to take that `*self` for themselves. - rt: Arc, req_id: RequestIdGenerator, - input_tx: mpsc::Sender, - output_rx: mpsc::Receiver, + client: lsp_server::Connection, trace: Vec, workspace_configuration: Value, expect_request_handlers: VecDeque, - _main_loop: AbortOnDrop, } impl MockClient { @@ -54,92 +42,24 @@ impl MockClient { capabilities: lsp_types::ClientCapabilities, workspace_configuration: Value, ) -> Self { - let rt = Arc::new(GuardedRuntime::start()); - let (service, loopback) = build_service_for_e2e_tests(); - - let (requests_tx, requests_rx) = mpsc::channel(0); - let (responses_tx, responses_rx) = mpsc::channel(0); - - let main_loop = rt - .spawn(Self::serve(service, loopback, requests_rx, responses_tx.clone())) - .abort_handle() - .into(); + let (init, client) = build_service_for_e2e_tests(); let mut this = Self { fixture, - rt, + client, req_id: RequestIdGenerator::default(), - input_tx: requests_tx, - output_rx: responses_rx, trace: Vec::new(), workspace_configuration, expect_request_handlers: Default::default(), - _main_loop: main_loop, }; + std::thread::spawn(|| init().run_for_tests()); + this.initialize(capabilities); this } - /// Copy-paste of [`tower_lsp::Server::serve`] that skips IO serialization. - async fn serve( - mut service: LspService, - loopback: ClientSocket, - mut requests_rx: mpsc::Receiver, - mut responses_tx: mpsc::Sender, - ) { - let (client_requests, mut client_responses) = loopback.split(); - let (client_requests, client_abort) = stream::abortable(client_requests); - let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(100); - - let process_server_tasks = server_tasks_rx - .buffer_unordered(4) - .filter_map(future::ready) - .map(Message::Response) - .map(Ok) - .forward(responses_tx.clone().sink_map_err(|_| unreachable!())) - .map(|_| ()); - - let print_output = client_requests - .map(Message::Request) - .map(Ok) - .forward(responses_tx.clone().sink_map_err(|_| unreachable!())) - .map(|_| ()); - - let read_input = async { - while let Some(msg) = requests_rx.next().await { - match msg { - Message::Request(req) => { - if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await { - eprintln!("{err:?}"); - break; - } - - let fut = service.call(req).unwrap_or_else(|err| { - eprintln!("{err:?}"); - None - }); - - server_tasks_tx.send(fut).await.unwrap() - } - Message::Response(res) => { - if let Err(err) = client_responses.send(res).await { - eprintln!("{err:?}"); - break; - } - } - } - } - - server_tasks_tx.disconnect(); - responses_tx.disconnect(); - client_abort.abort(); - }; - - join!(print_output, read_input, process_server_tasks); - } - /// Performs the `initialize`/`initialized` handshake with the server synchronously. fn initialize(&mut self, capabilities: lsp_types::ClientCapabilities) { let workspace_folders = Some(vec![lsp_types::WorkspaceFolder { @@ -159,6 +79,8 @@ impl MockClient { ..lsp_types::InitializeParams::default() }); + self.expect_request::(|_req| {}); + self.send_notification::(lsp_types::InitializedParams {}); } @@ -172,56 +94,49 @@ impl MockClient { /// Sends an arbitrary request to the server. pub fn send_request_untyped(&mut self, method: &'static str, params: Value) -> Value { let id = self.req_id.next(); - let message = Message::request(method, id.clone(), params); + let message = Message::Request(Request::new(id.clone(), method.to_owned(), params)); let mut expect_request_handlers = mem::take(&mut self.expect_request_handlers); let does_expect_requests = !expect_request_handlers.is_empty(); - let rt = self.rt.clone(); - rt.block_on(async { - self.input_tx.send(message.clone()).await.expect("failed to send request"); + self.client.sender.send(message.clone()).expect("failed to send request"); - while let Some(response_message) = - self.recv().await.unwrap_or_else(|err| panic!("{err:?}: {message:?}")) - { - match response_message { - Message::Request(res) if res.id().is_none() => { - // This looks like a notification, skip it. - } + while let Some(response_message) = + self.recv().unwrap_or_else(|err| panic!("{err:?}: {message:?}")) + { + match response_message { + Message::Notification(_) => { + // Skip notifications. + } - Message::Request(req) => { - if does_expect_requests { - if let Some(handler) = expect_request_handlers.pop_front() { - let response = (handler.f)(&req); - let message = Message::Response(response); - self.input_tx.send(message).await.expect("failed to send response"); - continue; - } + Message::Request(req) => { + if does_expect_requests { + if let Some(handler) = expect_request_handlers.pop_front() { + let response = (handler.f)(&req); + let message = Message::Response(response); + self.client.sender.send(message).expect("failed to send response"); + continue; } - - panic!("unexpected request: {:?}", req) } - Message::Response(res) => { - let (res_id, result) = res.into_parts(); - assert_eq!(res_id, id); + panic!("unexpected request: {:?}", req) + } - assert!( - !does_expect_requests || expect_request_handlers.is_empty(), - "expected more requests to be received from the client while \ - processing the current server one: {expect_request_handlers:?}" - ); + Message::Response(res) => { + let res_id = res.id; + let result = res.result.ok_or_else(|| res.error.unwrap()); - match result { - Ok(result) => return result, - Err(err) => panic!("error response: {:#?}", err), - } + assert_eq!(res_id, id); + + match result { + Ok(result) => return result, + Err(err) => panic!("error response: {:#?}", err), } } } + } - panic!("no response for request: {message:?}") - }) + panic!("no response for request: {message:?}") } /// Sends a typed notification to the server. @@ -235,10 +150,8 @@ impl MockClient { /// Sends an arbitrary notification to the server. pub fn send_notification_untyped(&mut self, method: &'static str, params: Value) { - let message = Message::notification(method, params); - self.rt.block_on(async { - self.input_tx.send(message).await.expect("failed to send notification"); - }) + let message = Message::Notification(Notification::new(method.to_string(), params)); + self.client.sender.send(message).expect("failed to send notification"); } } @@ -256,25 +169,23 @@ enum RecvError { NoMessage, } -impl From for RecvError { - fn from(_: tokio::time::error::Elapsed) -> Self { - RecvError::Timeout - } -} - /// Receiving messages. impl MockClient { /// Receives a message from the server. - async fn recv(&mut self) -> Result, RecvError> { + fn recv(&mut self) -> Result, RecvError> { const TIMEOUT: Duration = Duration::from_secs(3 * 60); - let message = timeout(TIMEOUT, self.output_rx.next()).await?; + let message = match self.client.receiver.recv_timeout(TIMEOUT) { + Ok(msg) => Some(msg), + Err(crossbeam::channel::RecvTimeoutError::Disconnected) => None, + Err(crossbeam::channel::RecvTimeoutError::Timeout) => return Err(RecvError::Timeout), + }; if let Some(message) = &message { self.trace.push(message.clone()); if let Message::Request(request) = &message { - if request.method() == ::METHOD { - self.auto_respond_to_workspace_configuration_request(request).await; + if request.method == ::METHOD { + self.auto_respond_to_workspace_configuration_request(request); } } } @@ -284,7 +195,7 @@ impl MockClient { /// Looks for a message that satisfies the given predicate in message trace or waits for a new /// one. - async fn wait_for_message( + fn wait_for_message( &mut self, predicate: impl Fn(&Message) -> Option, ) -> Result { @@ -295,7 +206,7 @@ impl MockClient { } loop { - let message = self.recv().await?.ok_or(RecvError::NoMessage)?; + let message = self.recv()?.ok_or(RecvError::NoMessage)?; if let Some(ret) = predicate(&message) { return Ok(ret); } @@ -304,16 +215,15 @@ impl MockClient { /// Looks for a client JSON-RPC request that satisfies the given predicate in message trace /// or waits for a new one. - fn wait_for_rpc_request(&mut self, predicate: impl Fn(&jsonrpc::Request) -> Option) -> T { - let rt = self.rt.clone(); - rt.block_on(async { - self.wait_for_message(|message| { - let Message::Request(req) = message else { return None }; - predicate(req) - }) - .await - .unwrap_or_else(|err| panic!("waiting for request failed: {err:?}")) + fn wait_for_rpc_notification( + &mut self, + predicate: impl Fn(&lsp_server::Notification) -> Option, + ) -> T { + self.wait_for_message(|message| { + let Message::Notification(notification) = message else { return None }; + predicate(notification) }) + .unwrap_or_else(|err| panic!("waiting for request failed: {err:?}")) } /// Looks for a typed client notification that satisfies the given predicate in message trace @@ -322,11 +232,11 @@ impl MockClient { where N: lsp_types::notification::Notification, { - self.wait_for_rpc_request(|req| { - if req.method() != N::METHOD { + self.wait_for_rpc_notification(|notification| { + if notification.method != N::METHOD { return None; } - let params = serde_json::from_value(req.params().cloned().unwrap_or_default()) + let params = serde_json::from_value(notification.params.clone()) .expect("failed to parse notification params"); predicate(¶ms).then_some(params) }) @@ -335,7 +245,7 @@ impl MockClient { /// Methods for handling interactive requests. impl MockClient { - /// Expect a specified request to be received from the served while processing the next client + /// Expect a specified request to be received from the server while processing the next client /// request. /// /// The handler is expected to return a response to the caught request. @@ -347,18 +257,16 @@ impl MockClient { R: lsp_types::request::Request, { self.expect_request_untyped(R::METHOD, move |req| { - assert_eq!(req.method(), R::METHOD); + assert_eq!(req.method, R::METHOD); - let Some(id) = req.id().cloned() else { - panic!("request ID is missing: {req:?}"); - }; + let id = req.id.clone(); - let params = serde_json::from_value(req.params().cloned().unwrap_or_default()) - .expect("failed to parse request params"); + let params = + serde_json::from_value(req.params.clone()).expect("failed to parse request params"); let result = handler(¶ms); let result = serde_json::to_value(result).expect("failed to serialize response"); - jsonrpc::Response::from_ok(id, result) + lsp_server::Response::new_ok(id, result) }) } @@ -369,7 +277,7 @@ impl MockClient { pub fn expect_request_untyped( &mut self, description: &'static str, - handler: impl FnOnce(&jsonrpc::Request) -> jsonrpc::Response + 'static, + handler: impl FnOnce(&lsp_server::Request) -> lsp_server::Response + 'static, ) { self.expect_request_handlers .push_back(ExpectRequestHandler { description, f: Box::new(handler) }) @@ -434,30 +342,23 @@ impl MockClient { impl MockClient { /// Assuming `request` is a `workspace/configuration` request, computes and sends a response to /// it. - async fn auto_respond_to_workspace_configuration_request( - &mut self, - request: &jsonrpc::Request, - ) { - assert_eq!( - request.method(), - ::METHOD - ); + fn auto_respond_to_workspace_configuration_request(&mut self, request: &lsp_server::Request) { + assert_eq!(request.method, ::METHOD); - let id = request.id().cloned().expect("request ID is missing"); + let id = request.id.clone(); - let params = - serde_json::from_value(request.params().expect("request params are missing").clone()) - .expect("failed to parse `workspace/configuration` params"); + let params = serde_json::from_value(request.params.clone()) + .expect("failed to parse `workspace/configuration` params"); let result = self.compute_workspace_configuration(params); - let result = Ok(serde_json::to_value(result) - .expect("failed to serialize `workspace/configuration` response")); + let result = serde_json::to_value(result) + .expect("failed to serialize `workspace/configuration` response"); - let message = Message::response(id, result); - self.input_tx + let message = Message::Response(Response::new_ok(id, result)); + self.client + .sender .send(message) - .await .expect("failed to send `workspace/configuration` response"); } @@ -499,7 +400,7 @@ impl AsRef for MockClient { /// The description is used in panic messages. struct ExpectRequestHandler { description: &'static str, - f: Box jsonrpc::Response>, + f: Box lsp_server::Response>, } impl fmt::Debug for ExpectRequestHandler { diff --git a/crates/cairo-lang-language-server/tests/e2e/support/mod.rs b/crates/cairo-lang-language-server/tests/e2e/support/mod.rs index 120214c3a8e..6feeeacc58c 100644 --- a/crates/cairo-lang-language-server/tests/e2e/support/mod.rs +++ b/crates/cairo-lang-language-server/tests/e2e/support/mod.rs @@ -4,7 +4,6 @@ pub mod fixture; pub mod jsonrpc; mod mock_client; pub mod normalize; -mod runtime; pub use self::cursor::cursors; pub use self::mock_client::MockClient; diff --git a/crates/cairo-lang-language-server/tests/e2e/support/runtime.rs b/crates/cairo-lang-language-server/tests/e2e/support/runtime.rs deleted file mode 100644 index 341684d218b..00000000000 --- a/crates/cairo-lang-language-server/tests/e2e/support/runtime.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::ops::{Deref, DerefMut}; -use std::time::Duration; - -use tokio::runtime::Runtime; -use tokio::task::AbortHandle; - -/// A wrapper over a multithreaded [`Runtime`] that ensures it is properly shut down when dropped. -pub struct GuardedRuntime(Option); - -impl GuardedRuntime { - /// Starts a new multithreaded [`Runtime`]. - pub fn start() -> Self { - let runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .expect("failed to start runtime"); - Self(Some(runtime)) - } -} - -impl Drop for GuardedRuntime { - fn drop(&mut self) { - if let Some(runtime) = self.0.take() { - runtime.shutdown_timeout(Duration::from_millis(300)); - } - } -} - -impl Deref for GuardedRuntime { - type Target = Runtime; - - fn deref(&self) -> &Self::Target { - self.0.as_ref().expect("use after drop") - } -} - -impl DerefMut for GuardedRuntime { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.as_mut().expect("use after drop") - } -} - -/// A guard object which aborts a linked task when dropped. -pub struct AbortOnDrop(AbortHandle); - -impl Drop for AbortOnDrop { - fn drop(&mut self) { - self.0.abort(); - } -} - -impl From for AbortOnDrop { - fn from(handle: AbortHandle) -> Self { - Self(handle) - } -} diff --git a/crates/cairo-lang-language-server/tests/e2e/workspace_configuration.rs b/crates/cairo-lang-language-server/tests/e2e/workspace_configuration.rs index 872010b96da..c4908d476b1 100644 --- a/crates/cairo-lang-language-server/tests/e2e/workspace_configuration.rs +++ b/crates/cairo-lang-language-server/tests/e2e/workspace_configuration.rs @@ -1,9 +1,9 @@ use indoc::indoc; +use lsp_server::Message; +use lsp_types::lsp_request; +use lsp_types::request::Request as _; use serde_json::json; -use tower_lsp::lsp_types::lsp_request; -use tower_lsp::lsp_types::request::Request as _; -use crate::support::jsonrpc::Message; use crate::support::sandbox; /// The LS used to panic when some files in Salsa database were interned with a relative path. @@ -48,7 +48,7 @@ fn relative_path_to_core() { .iter() .filter(|msg| { let Message::Request(req) = msg else { return false }; - req.method() == ::METHOD + req.method == ::METHOD }) .count(), 1