diff --git a/crates/mun_language_server/Cargo.toml b/crates/mun_language_server/Cargo.toml index aadec71e6..e79adebb4 100644 --- a/crates/mun_language_server/Cargo.toml +++ b/crates/mun_language_server/Cargo.toml @@ -39,3 +39,4 @@ paths = {path="../mun_paths", package="mun_paths"} [dev-dependencies] tempdir = "0.3.7" mun_test = { path = "../mun_test"} +insta = "0.16" diff --git a/crates/mun_language_server/src/analysis.rs b/crates/mun_language_server/src/analysis.rs index dcc0f0659..b51c0e827 100644 --- a/crates/mun_language_server/src/analysis.rs +++ b/crates/mun_language_server/src/analysis.rs @@ -1,10 +1,8 @@ -use crate::cancelation::Canceled; -use crate::change::AnalysisChange; -use crate::db::AnalysisDatabase; -use crate::diagnostics; -use crate::diagnostics::Diagnostic; -use hir::line_index::LineIndex; -use hir::SourceDatabase; +use crate::{ + cancelation::Canceled, change::AnalysisChange, db::AnalysisDatabase, diagnostics, + diagnostics::Diagnostic, file_structure, +}; +use hir::{line_index::LineIndex, AstDatabase, SourceDatabase}; use salsa::{ParallelDatabase, Snapshot}; use std::sync::Arc; @@ -74,6 +72,14 @@ impl AnalysisSnapshot { self.with_db(|db| db.line_index(file_id)) } + /// Returns a tree structure of the symbols of a file. + pub fn file_structure( + &self, + file_id: hir::FileId, + ) -> Cancelable> { + self.with_db(|db| file_structure::file_structure(&db.parse(file_id).tree())) + } + /// Performs an operation on that may be Canceled. fn with_db T + std::panic::UnwindSafe, T>( &self, diff --git a/crates/mun_language_server/src/capabilities.rs b/crates/mun_language_server/src/capabilities.rs index 03a7e4651..4afbe5997 100644 --- a/crates/mun_language_server/src/capabilities.rs +++ b/crates/mun_language_server/src/capabilities.rs @@ -1,11 +1,12 @@ use lsp_types::{ - ClientCapabilities, ServerCapabilities, TextDocumentSyncCapability, TextDocumentSyncKind, + ClientCapabilities, OneOf, ServerCapabilities, TextDocumentSyncCapability, TextDocumentSyncKind, }; /// Returns the capabilities of this LSP server implementation given the capabilities of the client. pub fn server_capabilities(_client_caps: &ClientCapabilities) -> ServerCapabilities { ServerCapabilities { text_document_sync: Some(TextDocumentSyncCapability::Kind(TextDocumentSyncKind::Full)), + document_symbol_provider: Some(OneOf::Left(true)), ..Default::default() } } diff --git a/crates/mun_language_server/src/conversion.rs b/crates/mun_language_server/src/conversion.rs index 439fea79f..e3bece615 100644 --- a/crates/mun_language_server/src/conversion.rs +++ b/crates/mun_language_server/src/conversion.rs @@ -1,3 +1,4 @@ +use crate::symbol_kind::SymbolKind; use lsp_types::Url; use mun_syntax::{TextRange, TextUnit}; use paths::AbsPathBuf; @@ -74,3 +75,12 @@ pub fn convert_uri(uri: &Url) -> anyhow::Result { .and_then(|path| AbsPathBuf::try_from(path).ok()) .ok_or_else(|| anyhow::anyhow!("invalid uri: {}", uri)) } + +/// Converts a symbol kind from this crate to one for the LSP protocol. +pub fn convert_symbol_kind(symbol_kind: SymbolKind) -> lsp_types::SymbolKind { + match symbol_kind { + SymbolKind::Function => lsp_types::SymbolKind::Function, + SymbolKind::Struct => lsp_types::SymbolKind::Struct, + SymbolKind::TypeAlias => lsp_types::SymbolKind::TypeParameter, + } +} diff --git a/crates/mun_language_server/src/file_structure.rs b/crates/mun_language_server/src/file_structure.rs new file mode 100644 index 000000000..4b28805af --- /dev/null +++ b/crates/mun_language_server/src/file_structure.rs @@ -0,0 +1,132 @@ +use crate::SymbolKind; +use mun_syntax::{ + ast::{self, NameOwner}, + match_ast, AstNode, SourceFile, SyntaxNode, TextRange, WalkEvent, +}; + +/// A description of a symbol in a source file. +#[derive(Debug, Clone)] +pub struct StructureNode { + /// An optional parent of this symbol. Refers to the index of the symbol in the collection that + /// this instance resides in. + pub parent: Option, + + /// The text label + pub label: String, + + /// The range to navigate to if selected + pub navigation_range: TextRange, + + /// The entire range of the node in the file + pub node_range: TextRange, + + /// The type of symbol + pub kind: SymbolKind, + + /// Optional detailed information + pub detail: Option, +} + +/// Provides a tree of symbols defined in a `SourceFile`. +pub(crate) fn file_structure(file: &SourceFile) -> Vec { + let mut result = Vec::new(); + let mut stack = Vec::new(); + + for event in file.syntax().preorder() { + match event { + WalkEvent::Enter(node) => { + if let Some(mut symbol) = try_convert_to_structure_node(&node) { + symbol.parent = stack.last().copied(); + stack.push(result.len()); + result.push(symbol); + } + } + WalkEvent::Leave(node) => { + if try_convert_to_structure_node(&node).is_some() { + stack.pop().unwrap(); + } + } + } + } + + result +} + +/// Tries to convert an ast node to something that would reside in the hierarchical file structure. +fn try_convert_to_structure_node(node: &SyntaxNode) -> Option { + /// Create a `StructureNode` from a declaration + fn decl(node: N, kind: SymbolKind) -> Option { + decl_with_detail(&node, None, kind) + } + + /// Create a `StructureNode` from a declaration with extra text detail + fn decl_with_detail( + node: &N, + detail: Option, + kind: SymbolKind, + ) -> Option { + let name = node.name()?; + + Some(StructureNode { + parent: None, + label: name.text().to_string(), + navigation_range: name.syntax().text_range(), + node_range: node.syntax().text_range(), + kind, + detail, + }) + } + + /// Given a `SyntaxNode` get the text without any whitespaces + fn collapse_whitespaces(node: &SyntaxNode, output: &mut String) { + let mut can_insert_ws = false; + node.text().for_each_chunk(|chunk| { + for line in chunk.lines() { + let line = line.trim(); + if line.is_empty() { + if can_insert_ws { + output.push(' '); + can_insert_ws = false; + } + } else { + output.push_str(line); + can_insert_ws = true; + } + } + }) + } + + /// Given a `SyntaxNode` construct a `StructureNode` by referring to the type of a node. + fn decl_with_type_ref( + node: &N, + type_ref: Option, + kind: SymbolKind, + ) -> Option { + let detail = type_ref.map(|type_ref| { + let mut detail = String::new(); + collapse_whitespaces(type_ref.syntax(), &mut detail); + detail + }); + decl_with_detail(node, detail, kind) + } + + match_ast! { + match node { + ast::FunctionDef(it) => { + let mut detail = String::from("fn"); + if let Some(param_list) = it.param_list() { + collapse_whitespaces(param_list.syntax(), &mut detail); + } + if let Some(ret_type) = it.ret_type() { + detail.push(' '); + collapse_whitespaces(ret_type.syntax(), &mut detail); + } + + decl_with_detail(&it, Some(detail), SymbolKind::Function) + }, + ast::StructDef(it) => decl(it, SymbolKind::Struct), + ast::TypeAliasDef(it) => decl_with_type_ref(&it, it.type_ref(), SymbolKind::TypeAlias), + _ => None + } + } +} diff --git a/crates/mun_language_server/src/handlers.rs b/crates/mun_language_server/src/handlers.rs new file mode 100644 index 000000000..ad00dc740 --- /dev/null +++ b/crates/mun_language_server/src/handlers.rs @@ -0,0 +1,148 @@ +use crate::{ + conversion::{convert_range, convert_symbol_kind}, + state::LanguageServerSnapshot, +}; +use lsp_types::DocumentSymbol; + +/// Computes the document symbols for a specific document. Converts the LSP types to internal +/// formats and calls [`LanguageServerSnapshot::file_structure`] to fetch the symbols in the +/// requested document. Once completed, returns the result converted back to LSP types. +pub(crate) fn handle_document_symbol( + snapshot: LanguageServerSnapshot, + params: lsp_types::DocumentSymbolParams, +) -> anyhow::Result> { + let file_id = snapshot.uri_to_file_id(¶ms.text_document.uri)?; + let line_index = snapshot.analysis.file_line_index(file_id)?; + + let mut parents: Vec<(DocumentSymbol, Option)> = Vec::new(); + + for symbol in snapshot.analysis.file_structure(file_id)? { + #[allow(deprecated)] + let doc_symbol = DocumentSymbol { + name: symbol.label, + detail: symbol.detail, + kind: convert_symbol_kind(symbol.kind), + tags: None, + deprecated: None, + range: convert_range(symbol.node_range, &line_index), + selection_range: convert_range(symbol.navigation_range, &line_index), + children: None, + }; + + parents.push((doc_symbol, symbol.parent)); + } + + Ok(Some(build_hierarchy_from_flat_list(parents).into())) +} + +/// Constructs a hierarchy of DocumentSymbols for a list of symbols that specify which index is the +/// parent of a symbol. The parent index must always be smaller than the current index. +fn build_hierarchy_from_flat_list( + mut symbols_and_parent: Vec<(DocumentSymbol, Option)>, +) -> Vec { + let mut result = Vec::new(); + + // Iterate over all elements in the list from back to front. + while let Some((mut node, parent_index)) = symbols_and_parent.pop() { + // If this node has children (added by the code below), they are in the reverse order. This + // is because we iterate the input from back to front. + if let Some(children) = &mut node.children { + children.reverse(); + } + + // Get the parent index of the current node. + let parent = match parent_index { + // If the parent doesnt have a node, directly use the result vector (its a root). + None => &mut result, + + // If there is a parent, get a reference to the children vector of that parent. + Some(i) => symbols_and_parent[i] + .0 + .children + .get_or_insert_with(Vec::new), + }; + + parent.push(node); + } + + // The items where pushed in the reverse order, so reverse it right back + result.reverse(); + result +} + +#[cfg(test)] +mod tests { + use crate::handlers::build_hierarchy_from_flat_list; + use lsp_types::{DocumentSymbol, SymbolKind}; + + #[test] + fn test_build_hierarchy_from_flat_list() { + #[allow(deprecated)] + let default_symbol = DocumentSymbol { + name: "".to_string(), + detail: None, + kind: SymbolKind::File, + tags: None, + deprecated: None, + range: Default::default(), + selection_range: Default::default(), + children: None, + }; + + let mut list = Vec::new(); + + list.push(( + DocumentSymbol { + name: "a".to_string(), + ..default_symbol.clone() + }, + None, + )); + + list.push(( + DocumentSymbol { + name: "b".to_string(), + ..default_symbol.clone() + }, + Some(0), + )); + + list.push(( + DocumentSymbol { + name: "c".to_string(), + ..default_symbol.clone() + }, + Some(0), + )); + + list.push(( + DocumentSymbol { + name: "d".to_string(), + ..default_symbol.clone() + }, + Some(1), + )); + + assert_eq!( + build_hierarchy_from_flat_list(list), + vec![DocumentSymbol { + name: "a".to_string(), + children: Some(vec![ + DocumentSymbol { + name: "b".to_string(), + children: Some(vec![DocumentSymbol { + name: "d".to_string(), + ..default_symbol.clone() + }]), + ..default_symbol.clone() + }, + DocumentSymbol { + name: "c".to_string(), + ..default_symbol.clone() + } + ]), + ..default_symbol.clone() + }] + ) + } +} diff --git a/crates/mun_language_server/src/lib.rs b/crates/mun_language_server/src/lib.rs index c1544d94d..6d2027430 100644 --- a/crates/mun_language_server/src/lib.rs +++ b/crates/mun_language_server/src/lib.rs @@ -7,6 +7,7 @@ pub use main_loop::main_loop; use paths::AbsPathBuf; use project::ProjectManifest; pub(crate) use state::LanguageServerState; +pub(crate) use symbol_kind::SymbolKind; mod analysis; mod cancelation; @@ -16,8 +17,11 @@ mod config; mod conversion; mod db; mod diagnostics; +mod file_structure; +mod handlers; mod main_loop; mod state; +mod symbol_kind; /// Deserializes a `T` from a json value. pub fn from_json( diff --git a/crates/mun_language_server/src/state.rs b/crates/mun_language_server/src/state.rs index 1f062f6c1..75bfa4be5 100644 --- a/crates/mun_language_server/src/state.rs +++ b/crates/mun_language_server/src/state.rs @@ -7,14 +7,14 @@ use crate::{ to_json, }; use crossbeam_channel::{select, unbounded, Receiver, Sender}; -use lsp_server::ReqQueue; +use lsp_server::{ReqQueue, Response}; use lsp_types::{ notification::Notification, notification::PublishDiagnostics, PublishDiagnosticsParams, Url, }; use parking_lot::RwLock; use paths::AbsPathBuf; use rustc_hash::FxHashSet; -use std::{ops::Deref, sync::Arc, time::Instant}; +use std::{convert::TryFrom, ops::Deref, sync::Arc, time::Instant}; use vfs::VirtualFileSystem; mod protocol; @@ -25,6 +25,7 @@ mod workspace; /// enables synchronizing resources like the connection with the client. #[derive(Debug)] pub(crate) enum Task { + Response(Response), Notify(lsp_server::Notification), } @@ -192,6 +193,7 @@ impl LanguageServerState { Task::Notify(notification) => { self.send(notification.into()); } + Task::Response(response) => self.respond(response), } Ok(()) } @@ -375,6 +377,23 @@ impl LanguageServerSnapshot { Ok(url) } + + /// Converts the specified `Url` to a `hir::FileId` + pub fn uri_to_file_id(&self, url: &Url) -> anyhow::Result { + url.to_file_path() + .map_err(|_| anyhow::anyhow!("invalid uri: {}", url)) + .and_then(|path| { + AbsPathBuf::try_from(path) + .map_err(|_| anyhow::anyhow!("url does not refer to absolute path: {}", url)) + }) + .and_then(|path| { + self.vfs + .read() + .file_id(&path) + .ok_or_else(|| anyhow::anyhow!("url does not refer to a file: {}", url)) + .map(|id| hir::FileId(id.0)) + }) + } } impl Drop for LanguageServerState { diff --git a/crates/mun_language_server/src/state/protocol.rs b/crates/mun_language_server/src/state/protocol.rs index 7192439eb..65bc55678 100644 --- a/crates/mun_language_server/src/state/protocol.rs +++ b/crates/mun_language_server/src/state/protocol.rs @@ -1,5 +1,5 @@ use super::LanguageServerState; -use crate::{conversion::convert_uri, state::RequestHandler}; +use crate::{conversion::convert_uri, handlers, state::RequestHandler}; use anyhow::Result; use dispatcher::{NotificationDispatcher, RequestDispatcher}; use lsp_types::notification::{ @@ -87,10 +87,11 @@ impl LanguageServerState { // Dispatch the event based on the type of event RequestDispatcher::new(self, request) - .on::(|state, _request| { + .on_sync::(|state, _request| { state.shutdown_requested = true; Ok(()) })? + .on::(handlers::handle_document_symbol)? .finish(); Ok(()) @@ -148,7 +149,7 @@ impl LanguageServerState { /// Sends a response to the client. This method logs the time it took us to reply /// to a request from the client. - fn respond(&mut self, response: lsp_server::Response) { + pub(super) fn respond(&mut self, response: lsp_server::Response) { if let Some((_method, start)) = self.request_queue.incoming.complete(response.id.clone()) { let duration = start.elapsed(); log::info!("handled req#{} in {:?}", response.id, duration); diff --git a/crates/mun_language_server/src/state/protocol/dispatcher.rs b/crates/mun_language_server/src/state/protocol/dispatcher.rs index f2b53d2ff..fd66a906f 100644 --- a/crates/mun_language_server/src/state/protocol/dispatcher.rs +++ b/crates/mun_language_server/src/state/protocol/dispatcher.rs @@ -1,7 +1,7 @@ use super::LanguageServerState; use crate::cancelation::is_canceled; use crate::from_json; -use anyhow::Result; +use crate::state::{LanguageServerSnapshot, Task}; use serde::de::DeserializeOwned; use serde::Serialize; @@ -20,11 +20,11 @@ impl<'a> RequestDispatcher<'a> { } } - /// Try to dispatch the event as the given Request type. - pub fn on( + /// Try to dispatch the event as the given Request type on the current thread. + pub fn on_sync( &mut self, - f: fn(&mut LanguageServerState, R::Params) -> Result, - ) -> Result<&mut Self> + compute_response_fn: fn(&mut LanguageServerState, R::Params) -> anyhow::Result, + ) -> anyhow::Result<&mut Self> where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + 'static, @@ -35,12 +35,42 @@ impl<'a> RequestDispatcher<'a> { None => return Ok(self), }; - let result = f(self.state, params); + let result = compute_response_fn(self.state, params); let response = result_to_response::(id, result); self.state.respond(response); Ok(self) } + /// Try to dispatch the event as the given Request type on the thread pool. + pub fn on( + &mut self, + compute_response_fn: fn(LanguageServerSnapshot, R::Params) -> anyhow::Result, + ) -> anyhow::Result<&mut Self> + where + R: lsp_types::request::Request + 'static, + R::Params: DeserializeOwned + 'static + Send, + R::Result: Serialize + 'static, + { + let (id, params) = match self.parse::() { + Some(it) => it, + None => return Ok(self), + }; + + self.state.thread_pool.execute({ + let snapshot = self.state.snapshot(); + let sender = self.state.task_sender.clone(); + + move || { + let result = compute_response_fn(snapshot, params); + sender + .send(Task::Response(result_to_response::(id, result))) + .unwrap(); + } + }); + + Ok(self) + } + /// Tries to parse the request as the specified type. If the request is of the specified type, /// the request is transferred and any subsequent call to this method will return None. If an /// error is encountered during parsing of the request parameters an error is send to the @@ -101,8 +131,8 @@ impl<'a> NotificationDispatcher<'a> { /// Try to dispatch the event as the given Notification type. pub fn on( &mut self, - f: fn(&mut LanguageServerState, N::Params) -> Result<()>, - ) -> Result<&mut Self> + handle_notification_fn: fn(&mut LanguageServerState, N::Params) -> anyhow::Result<()>, + ) -> anyhow::Result<&mut Self> where N: lsp_types::notification::Notification + 'static, N::Params: DeserializeOwned + Send + 'static, @@ -118,7 +148,7 @@ impl<'a> NotificationDispatcher<'a> { return Ok(self); } }; - f(self.state, params)?; + handle_notification_fn(self.state, params)?; Ok(self) } @@ -136,7 +166,7 @@ impl<'a> NotificationDispatcher<'a> { /// may have occurred. fn result_to_response( id: lsp_server::RequestId, - result: Result, + result: anyhow::Result, ) -> lsp_server::Response where R: lsp_types::request::Request + 'static, diff --git a/crates/mun_language_server/src/symbol_kind.rs b/crates/mun_language_server/src/symbol_kind.rs new file mode 100644 index 000000000..0f5130800 --- /dev/null +++ b/crates/mun_language_server/src/symbol_kind.rs @@ -0,0 +1,7 @@ +/// Defines a set of symbols that can live in a document. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum SymbolKind { + Function, + Struct, + TypeAlias, +} diff --git a/crates/mun_language_server/tests/initialization.rs b/crates/mun_language_server/tests/initialization.rs index 722a775d2..b98ee7fb3 100644 --- a/crates/mun_language_server/tests/initialization.rs +++ b/crates/mun_language_server/tests/initialization.rs @@ -20,3 +20,32 @@ fn add(a: i32, b: i32) -> i32 { .server() .wait_until_workspace_is_loaded(); } + +#[test] +fn test_document_symbols() { + let server = Project::with_fixture( + r#" + //- /mun.toml + [package] + name = "foo" + version = "0.0.0" + + //- /src/mod.mun + fn main() -> i32 {} + struct Foo {} + type Bar = Foo; + "#, + ) + .server() + .wait_until_workspace_is_loaded(); + + let symbols = server.send_request::( + lsp_types::DocumentSymbolParams { + text_document: server.doc_id("src/mod.mun"), + work_done_progress_params: Default::default(), + partial_result_params: Default::default(), + }, + ); + + insta::assert_debug_snapshot!(symbols); +} diff --git a/crates/mun_language_server/tests/snapshots/initialization__document_symbols.snap b/crates/mun_language_server/tests/snapshots/initialization__document_symbols.snap new file mode 100644 index 000000000..591a18382 --- /dev/null +++ b/crates/mun_language_server/tests/snapshots/initialization__document_symbols.snap @@ -0,0 +1,98 @@ +--- +source: crates/mun_language_server/tests/initialization.rs +expression: symbols +--- +Some( + Nested( + [ + DocumentSymbol { + name: "main", + detail: Some( + "fn() -> i32", + ), + kind: Function, + tags: None, + deprecated: None, + range: Range { + start: Position { + line: 0, + character: 0, + }, + end: Position { + line: 0, + character: 19, + }, + }, + selection_range: Range { + start: Position { + line: 0, + character: 3, + }, + end: Position { + line: 0, + character: 7, + }, + }, + children: None, + }, + DocumentSymbol { + name: "Foo", + detail: None, + kind: Struct, + tags: None, + deprecated: None, + range: Range { + start: Position { + line: 1, + character: 0, + }, + end: Position { + line: 1, + character: 13, + }, + }, + selection_range: Range { + start: Position { + line: 1, + character: 7, + }, + end: Position { + line: 1, + character: 10, + }, + }, + children: None, + }, + DocumentSymbol { + name: "Bar", + detail: Some( + "Foo", + ), + kind: TypeParameter, + tags: None, + deprecated: None, + range: Range { + start: Position { + line: 2, + character: 0, + }, + end: Position { + line: 2, + character: 15, + }, + }, + selection_range: Range { + start: Position { + line: 2, + character: 5, + }, + end: Position { + line: 2, + character: 8, + }, + }, + children: None, + }, + ], + ), +) diff --git a/crates/mun_language_server/tests/support.rs b/crates/mun_language_server/tests/support.rs index 94b8faa58..a6ab7a8bd 100644 --- a/crates/mun_language_server/tests/support.rs +++ b/crates/mun_language_server/tests/support.rs @@ -1,7 +1,8 @@ use crossbeam_channel::{after, select}; use lsp_server::{Connection, Message, Notification, Request}; use lsp_types::{ - notification::Exit, request::Shutdown, ProgressParams, ProgressParamsValue, WorkDoneProgress, + notification::Exit, request::Shutdown, ProgressParams, ProgressParamsValue, Url, + WorkDoneProgress, }; use mun_language_server::{main_loop, Config, FilesWatcher}; use mun_test::Fixture; @@ -75,7 +76,7 @@ pub struct Server { messages: RefCell>, worker: Option>, client: Connection, - _tmp_dir: tempdir::TempDir, + tmp_dir: tempdir::TempDir, } impl Server { @@ -92,7 +93,15 @@ impl Server { messages: RefCell::new(Vec::new()), worker: Some(worker), client, - _tmp_dir: tmp_dir, + tmp_dir, + } + } + + /// Returns the LSP TextDocumentIdentifier for the given path + pub fn doc_id(&self, rel_path: &str) -> lsp_types::TextDocumentIdentifier { + let path = self.tmp_dir.path().join(rel_path); + lsp_types::TextDocumentIdentifier { + uri: Url::from_file_path(path).unwrap(), } } @@ -130,19 +139,25 @@ impl Server { } /// Sends a request to the main loop and expects the specified value to be returned - fn assert_request( - &mut self, + fn assert_request_returns_value( + &self, params: R::Params, expected_response: Value, ) where R::Params: Serialize, { - let result = self.send_request::(params); + let result = self.send_request_for_value::(params); assert_eq!(result, expected_response); } + /// Sends a request to the language server, returning the response + pub fn send_request(&self, params: R::Params) -> R::Result { + let value = self.send_request_for_value::(params); + serde_json::from_value(value).unwrap() + } + /// Sends a request to main loop, returning the response - fn send_request(&self, params: R::Params) -> Value + fn send_request_for_value(&self, params: R::Params) -> Value where R::Params: Serialize, { @@ -210,7 +225,7 @@ impl Server { impl Drop for Server { fn drop(&mut self) { // Send the proper shutdown sequence to ensure the main loop terminates properly - self.assert_request::((), Value::Null); + self.assert_request_returns_value::((), Value::Null); self.notification::(()); // Cancel the main_loop diff --git a/crates/mun_syntax/src/lib.rs b/crates/mun_syntax/src/lib.rs index 15224f9f8..7d1b7f036 100644 --- a/crates/mun_syntax/src/lib.rs +++ b/crates/mun_syntax/src/lib.rs @@ -31,7 +31,7 @@ pub use crate::{ syntax_kind::SyntaxKind, syntax_node::{Direction, SyntaxElement, SyntaxNode, SyntaxToken, SyntaxTreeBuilder}, }; -pub use rowan::{SmolStr, TextRange, TextUnit}; +pub use rowan::{SmolStr, TextRange, TextUnit, WalkEvent}; use rowan::GreenNode; @@ -134,6 +134,31 @@ impl SourceFile { } } +/// Matches a `SyntaxNode` against an `ast` type. +/// +/// # Example: +/// +/// ```ignore +/// match_ast! { +/// match node { +/// ast::CallExpr(it) => { ... }, +/// _ => None, +/// } +/// } +/// ``` +#[macro_export] +macro_rules! match_ast { + (match $node:ident { $($tt:tt)* }) => { match_ast!(match ($node) { $($tt)* }) }; + + (match ($node:expr) { + $( ast::$ast:ident($it:ident) => $res:expr, )* + _ => $catch_all:expr $(,)? + }) => {{ + $( if let Some($it) = ast::$ast::cast($node.clone()) { $res } else )* + { $catch_all } + }}; +} + /// This tests does not assert anything and instead just shows off the crate's API. #[test] fn api_walkthrough() {