diff --git a/tooling/lsp/src/lib.rs b/tooling/lsp/src/lib.rs index 796137a8e9f..1c8e58dd8ff 100644 --- a/tooling/lsp/src/lib.rs +++ b/tooling/lsp/src/lib.rs @@ -66,6 +66,7 @@ mod requests; mod solver; mod types; mod utils; +mod visitor; #[cfg(test)] mod test_utils; diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index ad2082d75a8..f603d804872 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -11,13 +11,11 @@ use lsp_types::{CompletionItem, CompletionItemKind, CompletionParams, Completion use noirc_errors::{Location, Span}; use noirc_frontend::{ ast::{ - ArrayLiteral, AsTraitPath, BlockExpression, CallExpression, CastExpression, - ConstrainStatement, ConstructorExpression, Expression, ForLoopStatement, ForRange, - FunctionReturnType, Ident, IfExpression, IndexExpression, InfixExpression, LValue, Lambda, - LetStatement, Literal, MemberAccessExpression, MethodCallExpression, NoirFunction, - NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, PathKind, PathSegment, Pattern, - Statement, TraitImplItem, TraitItem, TypeImpl, UnresolvedGeneric, UnresolvedGenerics, - UnresolvedType, UseTree, UseTreeKind, + AsTraitPath, BlockExpression, ConstructorExpression, Expression, ForLoopStatement, + FunctionReturnType, Ident, IfExpression, Lambda, LetStatement, MemberAccessExpression, + NoirFunction, NoirStruct, NoirTraitImpl, NoirTypeAlias, Path, PathKind, PathSegment, + Pattern, Statement, TraitItem, TypeImpl, UnresolvedGeneric, UnresolvedGenerics, + UnresolvedType, UnresolvedTypeData, UseTree, UseTreeKind, }, graph::{CrateId, Dependency}, hir::{ @@ -26,12 +24,16 @@ use noirc_frontend::{ }, macros_api::{ModuleDefId, NodeInterner}, node_interner::ReferenceId, - parser::{Item, ItemKind}, + parser::{Item, ParsedSubModule}, ParsedModule, StructType, Type, }; use sort_text::underscore_sort_text; -use crate::{utils, LspState}; +use crate::{ + utils, + visitor::{Acceptor, ChildrenAcceptor, Visitor}, + LspState, +}; use super::process_request; @@ -136,7 +138,7 @@ impl<'a> NodeFinder<'a> { } fn find(&mut self, parsed_module: &ParsedModule) -> Option { - self.find_in_parsed_module(parsed_module); + parsed_module.accept(self); if self.completion_items.is_empty() { None @@ -154,490 +156,20 @@ impl<'a> NodeFinder<'a> { } } - fn find_in_parsed_module(&mut self, parsed_module: &ParsedModule) { - for item in &parsed_module.items { - self.find_in_item(item); - } - } - - fn find_in_item(&mut self, item: &Item) { - if !self.includes_span(item.span) { - return; - } - - match &item.kind { - ItemKind::Import(use_tree) => { - let mut prefixes = Vec::new(); - self.find_in_use_tree(use_tree, &mut prefixes); - } - ItemKind::Submodules(parsed_sub_module) => { - // Switch `self.module_id` to the submodule - let previous_module_id = self.module_id; - - let def_map = &self.def_maps[&self.module_id.krate]; - let Some(module_data) = def_map.modules().get(self.module_id.local_id.0) else { - return; - }; - if let Some(child_module) = module_data.children.get(&parsed_sub_module.name) { - self.module_id = - ModuleId { krate: self.module_id.krate, local_id: *child_module }; - } - - self.find_in_parsed_module(&parsed_sub_module.contents); - - // Restore the old module before continuing - self.module_id = previous_module_id; - } - ItemKind::Function(noir_function) => self.find_in_noir_function(noir_function), - ItemKind::TraitImpl(noir_trait_impl) => self.find_in_noir_trait_impl(noir_trait_impl), - ItemKind::Impl(type_impl) => self.find_in_type_impl(type_impl), - ItemKind::Global(let_statement) => self.find_in_let_statement(let_statement, false), - ItemKind::TypeAlias(noir_type_alias) => self.find_in_noir_type_alias(noir_type_alias), - ItemKind::Struct(noir_struct) => self.find_in_noir_struct(noir_struct), - ItemKind::Trait(noir_trait) => self.find_in_noir_trait(noir_trait), - ItemKind::ModuleDecl(_) => (), - } - } - - fn find_in_noir_function(&mut self, noir_function: &NoirFunction) { - let old_type_parameters = self.type_parameters.clone(); - self.collect_type_parameters_in_generics(&noir_function.def.generics); - - for param in &noir_function.def.parameters { - self.find_in_unresolved_type(¶m.typ); - } - - self.find_in_function_return_type(&noir_function.def.return_type); - - self.local_variables.clear(); - for param in &noir_function.def.parameters { - self.collect_local_variables(¶m.pattern); - } - - self.find_in_block_expression(&noir_function.def.body); - - self.type_parameters = old_type_parameters; - } - - fn find_in_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl) { - self.type_parameters.clear(); - self.collect_type_parameters_in_generics(&noir_trait_impl.impl_generics); - - for item in &noir_trait_impl.items { - self.find_in_trait_impl_item(item); - } - - self.type_parameters.clear(); - } - - fn find_in_trait_impl_item(&mut self, item: &TraitImplItem) { - match item { - TraitImplItem::Function(noir_function) => self.find_in_noir_function(noir_function), - TraitImplItem::Constant(_, _, _) => (), - TraitImplItem::Type { .. } => (), - } - } - - fn find_in_type_impl(&mut self, type_impl: &TypeImpl) { - self.type_parameters.clear(); - self.collect_type_parameters_in_generics(&type_impl.generics); - - for (method, span) in &type_impl.methods { - self.find_in_noir_function(method); - - // Optimization: stop looking in functions past the completion cursor - if span.end() as usize > self.byte_index { - break; - } - } - - self.type_parameters.clear(); - } - - fn find_in_noir_type_alias(&mut self, noir_type_alias: &NoirTypeAlias) { - self.find_in_unresolved_type(&noir_type_alias.typ); - } - - fn find_in_noir_struct(&mut self, noir_struct: &NoirStruct) { - self.type_parameters.clear(); - self.collect_type_parameters_in_generics(&noir_struct.generics); - - for (_name, unresolved_type) in &noir_struct.fields { - self.find_in_unresolved_type(unresolved_type); - } - - self.type_parameters.clear(); - } - - fn find_in_noir_trait(&mut self, noir_trait: &NoirTrait) { - for item in &noir_trait.items { - self.find_in_trait_item(item); - } - } - - fn find_in_trait_item(&mut self, trait_item: &TraitItem) { - match trait_item { - TraitItem::Function { - name: _, - generics, - parameters, - return_type, - where_clause, - body, - } => { - let old_type_parameters = self.type_parameters.clone(); - self.collect_type_parameters_in_generics(generics); - - for (_name, unresolved_type) in parameters { - self.find_in_unresolved_type(unresolved_type); - } - - self.find_in_function_return_type(return_type); - - for unresolved_trait_constraint in where_clause { - self.find_in_unresolved_type(&unresolved_trait_constraint.typ); - } - - if let Some(body) = body { - self.local_variables.clear(); - for (name, _) in parameters { - self.local_variables.insert(name.to_string(), name.span()); - } - self.find_in_block_expression(body); - }; - - self.type_parameters = old_type_parameters; - } - TraitItem::Constant { name: _, typ, default_value } => { - self.find_in_unresolved_type(typ); - - if let Some(default_value) = default_value { - self.find_in_expression(default_value); - } - } - TraitItem::Type { name: _ } => (), - } - } - - fn find_in_block_expression(&mut self, block_expression: &BlockExpression) { - let old_local_variables = self.local_variables.clone(); - for statement in &block_expression.statements { - self.find_in_statement(statement); - - // Optimization: stop looking in statements past the completion cursor - if statement.span.end() as usize > self.byte_index { - break; - } - } - self.local_variables = old_local_variables; - } - - fn find_in_statement(&mut self, statement: &Statement) { - match &statement.kind { - noirc_frontend::ast::StatementKind::Let(let_statement) => { - self.find_in_let_statement(let_statement, true); - } - noirc_frontend::ast::StatementKind::Constrain(constrain_statement) => { - self.find_in_constrain_statement(constrain_statement); - } - noirc_frontend::ast::StatementKind::Expression(expression) => { - self.find_in_expression(expression); - } - noirc_frontend::ast::StatementKind::Assign(assign_statement) => { - self.find_in_assign_statement(assign_statement); - } - noirc_frontend::ast::StatementKind::For(for_loop_statement) => { - self.find_in_for_loop_statement(for_loop_statement); - } - noirc_frontend::ast::StatementKind::Comptime(statement) => { - // When entering a comptime block, regular local variables shouldn't be offered anymore - let old_local_variables = self.local_variables.clone(); - self.local_variables.clear(); - - self.find_in_statement(statement); - - self.local_variables = old_local_variables; - } - noirc_frontend::ast::StatementKind::Semi(expression) => { - self.find_in_expression(expression); - } - noirc_frontend::ast::StatementKind::Break - | noirc_frontend::ast::StatementKind::Continue - | noirc_frontend::ast::StatementKind::Error => (), - } - } - fn find_in_let_statement( &mut self, let_statement: &LetStatement, collect_local_variables: bool, ) { self.find_in_unresolved_type(&let_statement.r#type); - self.find_in_expression(&let_statement.expression); + + let_statement.expression.accept(self); if collect_local_variables { self.collect_local_variables(&let_statement.pattern); } } - fn find_in_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) { - self.find_in_expression(&constrain_statement.0); - - if let Some(exp) = &constrain_statement.1 { - self.find_in_expression(exp); - } - } - - fn find_in_assign_statement( - &mut self, - assign_statement: &noirc_frontend::ast::AssignStatement, - ) { - self.find_in_lvalue(&assign_statement.lvalue); - self.find_in_expression(&assign_statement.expression); - } - - fn find_in_for_loop_statement(&mut self, for_loop_statement: &ForLoopStatement) { - let old_local_variables = self.local_variables.clone(); - let ident = &for_loop_statement.identifier; - self.local_variables.insert(ident.to_string(), ident.span()); - - self.find_in_for_range(&for_loop_statement.range); - self.find_in_expression(&for_loop_statement.block); - - self.local_variables = old_local_variables; - } - - fn find_in_lvalue(&mut self, lvalue: &LValue) { - match lvalue { - LValue::Ident(ident) => { - if self.byte == Some(b'.') && ident.span().end() as usize == self.byte_index - 1 { - let location = Location::new(ident.span(), self.file); - if let Some(ReferenceId::Local(definition_id)) = - self.interner.find_referenced(location) - { - let typ = self.interner.definition_type(definition_id); - let prefix = ""; - self.complete_type_fields_and_methods(&typ, prefix); - } - } - } - LValue::MemberAccess { object, field_name: _, span: _ } => self.find_in_lvalue(object), - LValue::Index { array, index, span: _ } => { - self.find_in_lvalue(array); - self.find_in_expression(index); - } - LValue::Dereference(lvalue, _) => self.find_in_lvalue(lvalue), - } - } - - fn find_in_for_range(&mut self, for_range: &ForRange) { - match for_range { - ForRange::Range(start, end) => { - self.find_in_expression(start); - self.find_in_expression(end); - } - ForRange::Array(expression) => self.find_in_expression(expression), - } - } - - fn find_in_expressions(&mut self, expressions: &[Expression]) { - for expression in expressions { - self.find_in_expression(expression); - } - } - - fn find_in_expression(&mut self, expression: &Expression) { - match &expression.kind { - noirc_frontend::ast::ExpressionKind::Literal(literal) => self.find_in_literal(literal), - noirc_frontend::ast::ExpressionKind::Block(block_expression) => { - self.find_in_block_expression(block_expression); - } - noirc_frontend::ast::ExpressionKind::Prefix(prefix_expression) => { - self.find_in_expression(&prefix_expression.rhs); - } - noirc_frontend::ast::ExpressionKind::Index(index_expression) => { - self.find_in_index_expression(index_expression); - } - noirc_frontend::ast::ExpressionKind::Call(call_expression) => { - self.find_in_call_expression(call_expression); - } - noirc_frontend::ast::ExpressionKind::MethodCall(method_call_expression) => { - self.find_in_method_call_expression(method_call_expression); - } - noirc_frontend::ast::ExpressionKind::Constructor(constructor_expression) => { - self.find_in_constructor_expression(constructor_expression); - } - noirc_frontend::ast::ExpressionKind::MemberAccess(member_access_expression) => { - self.find_in_member_access_expression(member_access_expression); - } - noirc_frontend::ast::ExpressionKind::Cast(cast_expression) => { - self.find_in_cast_expression(cast_expression); - } - noirc_frontend::ast::ExpressionKind::Infix(infix_expression) => { - self.find_in_infix_expression(infix_expression); - } - noirc_frontend::ast::ExpressionKind::If(if_expression) => { - self.find_in_if_expression(if_expression); - } - noirc_frontend::ast::ExpressionKind::Variable(path) => { - self.find_in_path(path, RequestedItems::AnyItems); - } - noirc_frontend::ast::ExpressionKind::Tuple(expressions) => { - self.find_in_expressions(expressions); - } - noirc_frontend::ast::ExpressionKind::Lambda(lambda) => self.find_in_lambda(lambda), - noirc_frontend::ast::ExpressionKind::Parenthesized(expression) => { - self.find_in_expression(expression); - } - noirc_frontend::ast::ExpressionKind::Unquote(expression) => { - self.find_in_expression(expression); - } - noirc_frontend::ast::ExpressionKind::Comptime(block_expression, _) => { - // When entering a comptime block, regular local variables shouldn't be offered anymore - let old_local_variables = self.local_variables.clone(); - self.local_variables.clear(); - - self.find_in_block_expression(block_expression); - - self.local_variables = old_local_variables; - } - noirc_frontend::ast::ExpressionKind::Unsafe(block_expression, _) => { - self.find_in_block_expression(block_expression); - } - noirc_frontend::ast::ExpressionKind::AsTraitPath(as_trait_path) => { - self.find_in_as_trait_path(as_trait_path); - } - noirc_frontend::ast::ExpressionKind::Quote(_) - | noirc_frontend::ast::ExpressionKind::Resolved(_) - | noirc_frontend::ast::ExpressionKind::Error => (), - } - - // "foo." (no identifier afterwards) is parsed as the expression on the left hand-side of the dot. - // Here we check if there's a dot at the completion position, and if the expression - // ends right before the dot. If so, it means we want to complete the expression's type fields and methods. - // We only do this after visiting nested expressions, because in an expression like `foo & bar.` we want - // to complete for `bar`, not for `foo & bar`. - if self.completion_items.is_empty() - && self.byte == Some(b'.') - && expression.span.end() as usize == self.byte_index - 1 - { - let location = Location::new(expression.span, self.file); - if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); - let prefix = ""; - self.complete_type_fields_and_methods(&typ, prefix); - } - } - } - - fn find_in_literal(&mut self, literal: &Literal) { - match literal { - Literal::Array(array_literal) => self.find_in_array_literal(array_literal), - Literal::Slice(array_literal) => self.find_in_array_literal(array_literal), - Literal::Bool(_) - | Literal::Integer(_, _) - | Literal::Str(_) - | Literal::RawStr(_, _) - | Literal::FmtStr(_) - | Literal::Unit => (), - } - } - - fn find_in_array_literal(&mut self, array_literal: &ArrayLiteral) { - match array_literal { - ArrayLiteral::Standard(expressions) => self.find_in_expressions(expressions), - ArrayLiteral::Repeated { repeated_element, length } => { - self.find_in_expression(repeated_element); - self.find_in_expression(length); - } - } - } - - fn find_in_index_expression(&mut self, index_expression: &IndexExpression) { - self.find_in_expression(&index_expression.collection); - self.find_in_expression(&index_expression.index); - } - - fn find_in_call_expression(&mut self, call_expression: &CallExpression) { - self.find_in_expression(&call_expression.func); - self.find_in_expressions(&call_expression.arguments); - } - - fn find_in_method_call_expression(&mut self, method_call_expression: &MethodCallExpression) { - self.find_in_expression(&method_call_expression.object); - self.find_in_expressions(&method_call_expression.arguments); - } - - fn find_in_constructor_expression(&mut self, constructor_expression: &ConstructorExpression) { - self.find_in_path(&constructor_expression.type_name, RequestedItems::OnlyTypes); - - for (_field_name, expression) in &constructor_expression.fields { - self.find_in_expression(expression); - } - } - - fn find_in_member_access_expression( - &mut self, - member_access_expression: &MemberAccessExpression, - ) { - let ident = &member_access_expression.rhs; - - if self.byte_index == ident.span().end() as usize { - // Assuming member_access_expression is of the form `foo.bar`, we are right after `bar` - let location = Location::new(member_access_expression.lhs.span, self.file); - if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); - let prefix = ident.to_string(); - self.complete_type_fields_and_methods(&typ, &prefix); - return; - } - } - - self.find_in_expression(&member_access_expression.lhs); - } - - fn find_in_cast_expression(&mut self, cast_expression: &CastExpression) { - self.find_in_expression(&cast_expression.lhs); - } - - fn find_in_infix_expression(&mut self, infix_expression: &InfixExpression) { - self.find_in_expression(&infix_expression.lhs); - self.find_in_expression(&infix_expression.rhs); - } - - fn find_in_if_expression(&mut self, if_expression: &IfExpression) { - self.find_in_expression(&if_expression.condition); - - let old_local_variables = self.local_variables.clone(); - self.find_in_expression(&if_expression.consequence); - self.local_variables = old_local_variables; - - if let Some(alternative) = &if_expression.alternative { - let old_local_variables = self.local_variables.clone(); - self.find_in_expression(alternative); - self.local_variables = old_local_variables; - } - } - - fn find_in_lambda(&mut self, lambda: &Lambda) { - for (_, unresolved_type) in &lambda.parameters { - self.find_in_unresolved_type(unresolved_type); - } - - let old_local_variables = self.local_variables.clone(); - for (pattern, _) in &lambda.parameters { - self.collect_local_variables(pattern); - } - - self.find_in_expression(&lambda.body); - - self.local_variables = old_local_variables; - } - - fn find_in_as_trait_path(&mut self, as_trait_path: &AsTraitPath) { - self.find_in_path(&as_trait_path.trait_path, RequestedItems::OnlyTypes); - } - fn find_in_function_return_type(&mut self, return_type: &FunctionReturnType) { match return_type { noirc_frontend::ast::FunctionReturnType::Default(_) => (), @@ -661,48 +193,48 @@ impl<'a> NodeFinder<'a> { } match &unresolved_type.typ { - noirc_frontend::ast::UnresolvedTypeData::Array(_, unresolved_type) => { + UnresolvedTypeData::Array(_, unresolved_type) => { self.find_in_unresolved_type(unresolved_type); } - noirc_frontend::ast::UnresolvedTypeData::Slice(unresolved_type) => { + UnresolvedTypeData::Slice(unresolved_type) => { self.find_in_unresolved_type(unresolved_type); } - noirc_frontend::ast::UnresolvedTypeData::Parenthesized(unresolved_type) => { + UnresolvedTypeData::Parenthesized(unresolved_type) => { self.find_in_unresolved_type(unresolved_type); } - noirc_frontend::ast::UnresolvedTypeData::Named(path, unresolved_types, _) => { + UnresolvedTypeData::Named(path, unresolved_types, _) => { self.find_in_path(path, RequestedItems::OnlyTypes); self.find_in_unresolved_types(unresolved_types); } - noirc_frontend::ast::UnresolvedTypeData::TraitAsType(path, unresolved_types) => { + UnresolvedTypeData::TraitAsType(path, unresolved_types) => { self.find_in_path(path, RequestedItems::OnlyTypes); self.find_in_unresolved_types(unresolved_types); } - noirc_frontend::ast::UnresolvedTypeData::MutableReference(unresolved_type) => { + UnresolvedTypeData::MutableReference(unresolved_type) => { self.find_in_unresolved_type(unresolved_type); } - noirc_frontend::ast::UnresolvedTypeData::Tuple(unresolved_types) => { + UnresolvedTypeData::Tuple(unresolved_types) => { self.find_in_unresolved_types(unresolved_types); } - noirc_frontend::ast::UnresolvedTypeData::Function(args, ret, env, _) => { + UnresolvedTypeData::Function(args, ret, env, _) => { self.find_in_unresolved_types(args); self.find_in_unresolved_type(ret); self.find_in_unresolved_type(env); } - noirc_frontend::ast::UnresolvedTypeData::AsTraitPath(as_trait_path) => { - self.find_in_as_trait_path(as_trait_path); - } - noirc_frontend::ast::UnresolvedTypeData::Expression(_) - | noirc_frontend::ast::UnresolvedTypeData::FormatString(_, _) - | noirc_frontend::ast::UnresolvedTypeData::String(_) - | noirc_frontend::ast::UnresolvedTypeData::Unspecified - | noirc_frontend::ast::UnresolvedTypeData::Quoted(_) - | noirc_frontend::ast::UnresolvedTypeData::FieldElement - | noirc_frontend::ast::UnresolvedTypeData::Integer(_, _) - | noirc_frontend::ast::UnresolvedTypeData::Bool - | noirc_frontend::ast::UnresolvedTypeData::Unit - | noirc_frontend::ast::UnresolvedTypeData::Resolved(_) - | noirc_frontend::ast::UnresolvedTypeData::Error => (), + UnresolvedTypeData::AsTraitPath(as_trait_path) => { + as_trait_path.accept(self); + } + UnresolvedTypeData::Expression(_) + | UnresolvedTypeData::FormatString(_, _) + | UnresolvedTypeData::String(_) + | UnresolvedTypeData::Unspecified + | UnresolvedTypeData::Quoted(_) + | UnresolvedTypeData::FieldElement + | UnresolvedTypeData::Integer(_, _) + | UnresolvedTypeData::Bool + | UnresolvedTypeData::Unit + | UnresolvedTypeData::Resolved(_) + | UnresolvedTypeData::Error => (), } } @@ -1174,6 +706,316 @@ impl<'a> NodeFinder<'a> { } } +impl<'a> Visitor for NodeFinder<'a> { + fn visit_item(&mut self, item: &Item) -> bool { + self.includes_span(item.span) + } + + fn visit_use_tree(&mut self, use_tree: &UseTree) -> bool { + let mut prefixes = Vec::new(); + self.find_in_use_tree(use_tree, &mut prefixes); + false + } + + fn visit_parsed_submodule(&mut self, parsed_sub_module: &ParsedSubModule) -> bool { + // Switch `self.module_id` to the submodule + let previous_module_id = self.module_id; + + let def_map = &self.def_maps[&self.module_id.krate]; + let Some(module_data) = def_map.modules().get(self.module_id.local_id.0) else { + return false; + }; + if let Some(child_module) = module_data.children.get(&parsed_sub_module.name) { + self.module_id = ModuleId { krate: self.module_id.krate, local_id: *child_module }; + } + + parsed_sub_module.contents.accept(self); + + // Restore the old module before continuing + self.module_id = previous_module_id; + + false + } + + fn visit_noir_function(&mut self, noir_function: &NoirFunction) -> bool { + let old_type_parameters = self.type_parameters.clone(); + self.collect_type_parameters_in_generics(&noir_function.def.generics); + + for param in &noir_function.def.parameters { + self.find_in_unresolved_type(¶m.typ); + } + + self.find_in_function_return_type(&noir_function.def.return_type); + + self.local_variables.clear(); + for param in &noir_function.def.parameters { + self.collect_local_variables(¶m.pattern); + } + + noir_function.def.body.accept(self); + + self.type_parameters = old_type_parameters; + + false + } + + fn visit_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl) -> bool { + self.type_parameters.clear(); + self.collect_type_parameters_in_generics(&noir_trait_impl.impl_generics); + + for item in &noir_trait_impl.items { + item.accept(self); + } + + self.type_parameters.clear(); + + false + } + + fn visit_type_impl(&mut self, type_impl: &TypeImpl) -> bool { + self.type_parameters.clear(); + self.collect_type_parameters_in_generics(&type_impl.generics); + + for (method, span) in &type_impl.methods { + method.accept(self); + + // Optimization: stop looking in functions past the completion cursor + if span.end() as usize > self.byte_index { + break; + } + } + + self.type_parameters.clear(); + + false + } + + fn visit_noir_type_alias(&mut self, noir_type_alias: &NoirTypeAlias) { + self.find_in_unresolved_type(&noir_type_alias.typ); + } + + fn visit_noir_struct(&mut self, noir_struct: &NoirStruct) { + self.type_parameters.clear(); + self.collect_type_parameters_in_generics(&noir_struct.generics); + + for (_name, unresolved_type) in &noir_struct.fields { + self.find_in_unresolved_type(unresolved_type); + } + + self.type_parameters.clear(); + } + + fn visit_trait_item(&mut self, trait_item: &TraitItem) -> bool { + match trait_item { + TraitItem::Function { + name: _, + generics, + parameters, + return_type, + where_clause, + body, + } => { + let old_type_parameters = self.type_parameters.clone(); + self.collect_type_parameters_in_generics(generics); + + for (_name, unresolved_type) in parameters { + self.find_in_unresolved_type(unresolved_type); + } + + self.find_in_function_return_type(return_type); + + for unresolved_trait_constraint in where_clause { + self.find_in_unresolved_type(&unresolved_trait_constraint.typ); + } + + if let Some(body) = body { + self.local_variables.clear(); + for (name, _) in parameters { + self.local_variables.insert(name.to_string(), name.span()); + } + body.accept(self); + }; + + self.type_parameters = old_type_parameters; + + false + } + TraitItem::Constant { name: _, typ, default_value: _ } => { + self.find_in_unresolved_type(typ); + + true + } + TraitItem::Type { name: _ } => false, + } + } + + fn visit_block_expression(&mut self, block_expression: &BlockExpression) -> bool { + let old_local_variables = self.local_variables.clone(); + for statement in &block_expression.statements { + statement.accept(self); + + // Optimization: stop looking in statements past the completion cursor + if statement.span.end() as usize > self.byte_index { + break; + } + } + self.local_variables = old_local_variables; + + false + } + + fn visit_comptime_statement(&mut self, statement: &Statement) -> bool { + // When entering a comptime block, regular local variables shouldn't be offered anymore + let old_local_variables = self.local_variables.clone(); + self.local_variables.clear(); + + statement.accept(self); + + self.local_variables = old_local_variables; + + false + } + + fn visit_let_statement(&mut self, let_statement: &LetStatement) -> bool { + self.find_in_let_statement(let_statement, true); + false + } + + fn visit_global(&mut self, let_statement: &LetStatement) -> bool { + self.find_in_let_statement(let_statement, false); + false + } + + fn visit_for_loop_statement(&mut self, for_loop_statement: &ForLoopStatement) -> bool { + let old_local_variables = self.local_variables.clone(); + let ident = &for_loop_statement.identifier; + self.local_variables.insert(ident.to_string(), ident.span()); + + for_loop_statement.accept_children(self); + + self.local_variables = old_local_variables; + + false + } + + fn visit_lvalue_ident(&mut self, ident: &Ident) { + if self.byte == Some(b'.') && ident.span().end() as usize == self.byte_index - 1 { + let location = Location::new(ident.span(), self.file); + if let Some(ReferenceId::Local(definition_id)) = self.interner.find_referenced(location) + { + let typ = self.interner.definition_type(definition_id); + let prefix = ""; + self.complete_type_fields_and_methods(&typ, prefix); + } + } + } + + fn visit_variable(&mut self, path: &Path) { + self.find_in_path(path, RequestedItems::AnyItems); + } + + fn visit_comptime_expression(&mut self, block_expression: &BlockExpression) -> bool { + // When entering a comptime block, regular local variables shouldn't be offered anymore + let old_local_variables = self.local_variables.clone(); + self.local_variables.clear(); + + block_expression.accept(self); + + self.local_variables = old_local_variables; + + false + } + + fn visit_expression(&mut self, expression: &Expression) -> bool { + expression.accept_children(self); + + // "foo." (no identifier afterwards) is parsed as the expression on the left hand-side of the dot. + // Here we check if there's a dot at the completion position, and if the expression + // ends right before the dot. If so, it means we want to complete the expression's type fields and methods. + // We only do this after visiting nested expressions, because in an expression like `foo & bar.` we want + // to complete for `bar`, not for `foo & bar`. + if self.completion_items.is_empty() + && self.byte == Some(b'.') + && expression.span.end() as usize == self.byte_index - 1 + { + let location = Location::new(expression.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = ""; + self.complete_type_fields_and_methods(&typ, prefix); + } + } + + false + } + + fn visit_constructor_expression( + &mut self, + constructor_expression: &ConstructorExpression, + ) -> bool { + self.find_in_path(&constructor_expression.type_name, RequestedItems::OnlyTypes); + + true + } + + fn visit_member_access_expression( + &mut self, + member_access_expression: &MemberAccessExpression, + ) -> bool { + let ident = &member_access_expression.rhs; + + if self.byte_index == ident.span().end() as usize { + // Assuming member_access_expression is of the form `foo.bar`, we are right after `bar` + let location = Location::new(member_access_expression.lhs.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = ident.to_string(); + self.complete_type_fields_and_methods(&typ, &prefix); + return false; + } + } + + true + } + + fn visit_if_expression(&mut self, if_expression: &IfExpression) -> bool { + if_expression.condition.accept(self); + + let old_local_variables = self.local_variables.clone(); + if_expression.consequence.accept(self); + self.local_variables = old_local_variables; + + if let Some(alternative) = &if_expression.alternative { + let old_local_variables = self.local_variables.clone(); + alternative.accept(self); + self.local_variables = old_local_variables; + } + + false + } + + fn visit_lambda(&mut self, lambda: &Lambda) -> bool { + for (_, unresolved_type) in &lambda.parameters { + self.find_in_unresolved_type(unresolved_type); + } + + let old_local_variables = self.local_variables.clone(); + for (pattern, _) in &lambda.parameters { + self.collect_local_variables(pattern); + } + + lambda.body.accept(self); + + self.local_variables = old_local_variables; + + false + } + + fn visit_as_trait_path(&mut self, as_trait_path: &AsTraitPath) { + self.find_in_path(&as_trait_path.trait_path, RequestedItems::OnlyTypes); + } +} + fn name_matches(name: &str, prefix: &str) -> bool { name.starts_with(prefix) } diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index c2c69185547..cc0aeb56873 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -7,19 +7,25 @@ use lsp_types::{ }; use noirc_errors::{Location, Span}; use noirc_frontend::{ - ast::{CallExpression, Expression, FunctionReturnType, MethodCallExpression}, + ast::{ + CallExpression, Expression, ExpressionKind, FunctionReturnType, MethodCallExpression, + Statement, + }, hir_def::{function::FuncMeta, stmt::HirPattern}, macros_api::NodeInterner, node_interner::ReferenceId, ParsedModule, Type, }; -use crate::{utils, LspState}; +use crate::{ + utils, + visitor::{Acceptor, ChildrenAcceptor, Visitor}, + LspState, +}; use super::process_request; mod tests; -mod traversal; pub(crate) fn on_signature_help_request( state: &mut LspState, @@ -61,15 +67,12 @@ impl<'a> SignatureFinder<'a> { } fn find(&mut self, parsed_module: &ParsedModule) -> Option { - self.find_in_parsed_module(parsed_module); + parsed_module.accept(self); self.signature_help.clone() } fn find_in_call_expression(&mut self, call_expression: &CallExpression, span: Span) { - self.find_in_expression(&call_expression.func); - self.find_in_expressions(&call_expression.arguments); - let arguments_span = Span::from(call_expression.func.span.end() + 1..span.end() - 1); let span = call_expression.func.span; let name_span = Span::from(span.end() - 1..span.end()); @@ -88,9 +91,6 @@ impl<'a> SignatureFinder<'a> { method_call_expression: &MethodCallExpression, span: Span, ) { - self.find_in_expression(&method_call_expression.object); - self.find_in_expressions(&method_call_expression.arguments); - let arguments_span = Span::from(method_call_expression.method_name.span().end() + 1..span.end() - 1); let name_span = method_call_expression.method_name.span(); @@ -289,3 +289,33 @@ impl<'a> SignatureFinder<'a> { span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize } } + +impl<'a> Visitor for SignatureFinder<'a> { + fn visit_expression(&mut self, expression: &Expression) -> bool { + if !self.includes_span(expression.span) { + return false; + } + + match &expression.kind { + ExpressionKind::Call(call_expression) => { + call_expression.accept_children(self); + + self.find_in_call_expression(call_expression, expression.span); + + false + } + ExpressionKind::MethodCall(method_call_expression) => { + method_call_expression.accept_children(self); + + self.find_in_method_call_expression(method_call_expression, expression.span); + + false + } + _ => true, + } + } + + fn visit_statement(&mut self, statement: &Statement) -> bool { + self.includes_span(statement.span) + } +} diff --git a/tooling/lsp/src/requests/signature_help/traversal.rs b/tooling/lsp/src/requests/signature_help/traversal.rs deleted file mode 100644 index ecb3bf46487..00000000000 --- a/tooling/lsp/src/requests/signature_help/traversal.rs +++ /dev/null @@ -1,308 +0,0 @@ -/// This file includes the signature help logic that's just about -/// traversing the AST without any additional logic. -use super::SignatureFinder; - -use noirc_frontend::{ - ast::{ - ArrayLiteral, BlockExpression, CastExpression, ConstrainStatement, ConstructorExpression, - Expression, ForLoopStatement, ForRange, IfExpression, IndexExpression, InfixExpression, - LValue, Lambda, LetStatement, Literal, MemberAccessExpression, NoirFunction, NoirTrait, - NoirTraitImpl, Statement, TraitImplItem, TraitItem, TypeImpl, - }, - parser::{Item, ItemKind}, - ParsedModule, -}; - -impl<'a> SignatureFinder<'a> { - pub(super) fn find_in_parsed_module(&mut self, parsed_module: &ParsedModule) { - for item in &parsed_module.items { - self.find_in_item(item); - } - } - - pub(super) fn find_in_item(&mut self, item: &Item) { - if !self.includes_span(item.span) { - return; - } - - match &item.kind { - ItemKind::Submodules(parsed_sub_module) => { - self.find_in_parsed_module(&parsed_sub_module.contents); - } - ItemKind::Function(noir_function) => self.find_in_noir_function(noir_function), - ItemKind::TraitImpl(noir_trait_impl) => self.find_in_noir_trait_impl(noir_trait_impl), - ItemKind::Impl(type_impl) => self.find_in_type_impl(type_impl), - ItemKind::Global(let_statement) => self.find_in_let_statement(let_statement), - ItemKind::Trait(noir_trait) => self.find_in_noir_trait(noir_trait), - ItemKind::Import(..) - | ItemKind::TypeAlias(_) - | ItemKind::Struct(_) - | ItemKind::ModuleDecl(_) => (), - } - } - - pub(super) fn find_in_noir_function(&mut self, noir_function: &NoirFunction) { - self.find_in_block_expression(&noir_function.def.body); - } - - pub(super) fn find_in_noir_trait_impl(&mut self, noir_trait_impl: &NoirTraitImpl) { - for item in &noir_trait_impl.items { - self.find_in_trait_impl_item(item); - } - } - - pub(super) fn find_in_trait_impl_item(&mut self, item: &TraitImplItem) { - match item { - TraitImplItem::Function(noir_function) => self.find_in_noir_function(noir_function), - TraitImplItem::Constant(_, _, _) => (), - TraitImplItem::Type { .. } => (), - } - } - - pub(super) fn find_in_type_impl(&mut self, type_impl: &TypeImpl) { - for (method, span) in &type_impl.methods { - if self.includes_span(*span) { - self.find_in_noir_function(method); - } - } - } - - pub(super) fn find_in_noir_trait(&mut self, noir_trait: &NoirTrait) { - for item in &noir_trait.items { - self.find_in_trait_item(item); - } - } - - pub(super) fn find_in_trait_item(&mut self, trait_item: &TraitItem) { - match trait_item { - TraitItem::Function { body, .. } => { - if let Some(body) = body { - self.find_in_block_expression(body); - }; - } - TraitItem::Constant { default_value, .. } => { - if let Some(default_value) = default_value { - self.find_in_expression(default_value); - } - } - TraitItem::Type { .. } => (), - } - } - - pub(super) fn find_in_block_expression(&mut self, block_expression: &BlockExpression) { - for statement in &block_expression.statements { - if self.includes_span(statement.span) { - self.find_in_statement(statement); - } - } - } - - pub(super) fn find_in_statement(&mut self, statement: &Statement) { - if !self.includes_span(statement.span) { - return; - } - - match &statement.kind { - noirc_frontend::ast::StatementKind::Let(let_statement) => { - self.find_in_let_statement(let_statement); - } - noirc_frontend::ast::StatementKind::Constrain(constrain_statement) => { - self.find_in_constrain_statement(constrain_statement); - } - noirc_frontend::ast::StatementKind::Expression(expression) => { - self.find_in_expression(expression); - } - noirc_frontend::ast::StatementKind::Assign(assign_statement) => { - self.find_in_assign_statement(assign_statement); - } - noirc_frontend::ast::StatementKind::For(for_loop_statement) => { - self.find_in_for_loop_statement(for_loop_statement); - } - noirc_frontend::ast::StatementKind::Comptime(statement) => { - self.find_in_statement(statement); - } - noirc_frontend::ast::StatementKind::Semi(expression) => { - self.find_in_expression(expression); - } - noirc_frontend::ast::StatementKind::Break - | noirc_frontend::ast::StatementKind::Continue - | noirc_frontend::ast::StatementKind::Error => (), - } - } - - pub(super) fn find_in_let_statement(&mut self, let_statement: &LetStatement) { - self.find_in_expression(&let_statement.expression); - } - - pub(super) fn find_in_constrain_statement(&mut self, constrain_statement: &ConstrainStatement) { - self.find_in_expression(&constrain_statement.0); - - if let Some(exp) = &constrain_statement.1 { - self.find_in_expression(exp); - } - } - - pub(super) fn find_in_assign_statement( - &mut self, - assign_statement: &noirc_frontend::ast::AssignStatement, - ) { - self.find_in_lvalue(&assign_statement.lvalue); - self.find_in_expression(&assign_statement.expression); - } - - pub(super) fn find_in_for_loop_statement(&mut self, for_loop_statement: &ForLoopStatement) { - self.find_in_for_range(&for_loop_statement.range); - self.find_in_expression(&for_loop_statement.block); - } - - pub(super) fn find_in_lvalue(&mut self, lvalue: &LValue) { - match lvalue { - LValue::Ident(_) => (), - LValue::MemberAccess { object, field_name: _, span: _ } => self.find_in_lvalue(object), - LValue::Index { array, index, span: _ } => { - self.find_in_lvalue(array); - self.find_in_expression(index); - } - LValue::Dereference(lvalue, _) => self.find_in_lvalue(lvalue), - } - } - - pub(super) fn find_in_for_range(&mut self, for_range: &ForRange) { - match for_range { - ForRange::Range(start, end) => { - self.find_in_expression(start); - self.find_in_expression(end); - } - ForRange::Array(expression) => self.find_in_expression(expression), - } - } - - pub(super) fn find_in_expressions(&mut self, expressions: &[Expression]) { - for expression in expressions { - self.find_in_expression(expression); - } - } - - pub(super) fn find_in_expression(&mut self, expression: &Expression) { - match &expression.kind { - noirc_frontend::ast::ExpressionKind::Literal(literal) => self.find_in_literal(literal), - noirc_frontend::ast::ExpressionKind::Block(block_expression) => { - self.find_in_block_expression(block_expression); - } - noirc_frontend::ast::ExpressionKind::Prefix(prefix_expression) => { - self.find_in_expression(&prefix_expression.rhs); - } - noirc_frontend::ast::ExpressionKind::Index(index_expression) => { - self.find_in_index_expression(index_expression); - } - noirc_frontend::ast::ExpressionKind::Call(call_expression) => { - self.find_in_call_expression(call_expression, expression.span); - } - noirc_frontend::ast::ExpressionKind::MethodCall(method_call_expression) => { - self.find_in_method_call_expression(method_call_expression, expression.span); - } - noirc_frontend::ast::ExpressionKind::Constructor(constructor_expression) => { - self.find_in_constructor_expression(constructor_expression); - } - noirc_frontend::ast::ExpressionKind::MemberAccess(member_access_expression) => { - self.find_in_member_access_expression(member_access_expression); - } - noirc_frontend::ast::ExpressionKind::Cast(cast_expression) => { - self.find_in_cast_expression(cast_expression); - } - noirc_frontend::ast::ExpressionKind::Infix(infix_expression) => { - self.find_in_infix_expression(infix_expression); - } - noirc_frontend::ast::ExpressionKind::If(if_expression) => { - self.find_in_if_expression(if_expression); - } - noirc_frontend::ast::ExpressionKind::Tuple(expressions) => { - self.find_in_expressions(expressions); - } - noirc_frontend::ast::ExpressionKind::Lambda(lambda) => self.find_in_lambda(lambda), - noirc_frontend::ast::ExpressionKind::Parenthesized(expression) => { - self.find_in_expression(expression); - } - noirc_frontend::ast::ExpressionKind::Unquote(expression) => { - self.find_in_expression(expression); - } - noirc_frontend::ast::ExpressionKind::Comptime(block_expression, _) => { - self.find_in_block_expression(block_expression); - } - noirc_frontend::ast::ExpressionKind::Unsafe(block_expression, _) => { - self.find_in_block_expression(block_expression); - } - noirc_frontend::ast::ExpressionKind::Variable(_) - | noirc_frontend::ast::ExpressionKind::AsTraitPath(_) - | noirc_frontend::ast::ExpressionKind::Quote(_) - | noirc_frontend::ast::ExpressionKind::Resolved(_) - | noirc_frontend::ast::ExpressionKind::Error => (), - } - } - - pub(super) fn find_in_literal(&mut self, literal: &Literal) { - match literal { - Literal::Array(array_literal) => self.find_in_array_literal(array_literal), - Literal::Slice(array_literal) => self.find_in_array_literal(array_literal), - Literal::Bool(_) - | Literal::Integer(_, _) - | Literal::Str(_) - | Literal::RawStr(_, _) - | Literal::FmtStr(_) - | Literal::Unit => (), - } - } - - pub(super) fn find_in_array_literal(&mut self, array_literal: &ArrayLiteral) { - match array_literal { - ArrayLiteral::Standard(expressions) => self.find_in_expressions(expressions), - ArrayLiteral::Repeated { repeated_element, length } => { - self.find_in_expression(repeated_element); - self.find_in_expression(length); - } - } - } - - pub(super) fn find_in_index_expression(&mut self, index_expression: &IndexExpression) { - self.find_in_expression(&index_expression.collection); - self.find_in_expression(&index_expression.index); - } - - pub(super) fn find_in_constructor_expression( - &mut self, - constructor_expression: &ConstructorExpression, - ) { - for (_field_name, expression) in &constructor_expression.fields { - self.find_in_expression(expression); - } - } - - pub(super) fn find_in_member_access_expression( - &mut self, - member_access_expression: &MemberAccessExpression, - ) { - self.find_in_expression(&member_access_expression.lhs); - } - - pub(super) fn find_in_cast_expression(&mut self, cast_expression: &CastExpression) { - self.find_in_expression(&cast_expression.lhs); - } - - pub(super) fn find_in_infix_expression(&mut self, infix_expression: &InfixExpression) { - self.find_in_expression(&infix_expression.lhs); - self.find_in_expression(&infix_expression.rhs); - } - - pub(super) fn find_in_if_expression(&mut self, if_expression: &IfExpression) { - self.find_in_expression(&if_expression.condition); - self.find_in_expression(&if_expression.consequence); - - if let Some(alternative) = &if_expression.alternative { - self.find_in_expression(alternative); - } - } - - pub(super) fn find_in_lambda(&mut self, lambda: &Lambda) { - self.find_in_expression(&lambda.body); - } -} diff --git a/tooling/lsp/src/visitor.rs b/tooling/lsp/src/visitor.rs new file mode 100644 index 00000000000..d71eb5e1982 --- /dev/null +++ b/tooling/lsp/src/visitor.rs @@ -0,0 +1,850 @@ +use noirc_frontend::{ + ast::{ + ArrayLiteral, AsTraitPath, AssignStatement, BlockExpression, CallExpression, + CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind, + ForLoopStatement, ForRange, Ident, IfExpression, IndexExpression, InfixExpression, LValue, + Lambda, LetStatement, Literal, MemberAccessExpression, MethodCallExpression, + ModuleDeclaration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, + PrefixExpression, Statement, StatementKind, TraitImplItem, TraitItem, TypeImpl, UseTree, + UseTreeKind, + }, + parser::{Item, ItemKind, ParsedSubModule}, + ParsedModule, +}; + +/// Implements the [Visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for Noir's AST. +/// +/// In this implementation, methods must return a bool: +/// - true means children must be visited +/// - false means children must not be visited, either because the visitor implementation +/// will visit children of interest manually, or because no children are of interest +pub(crate) trait Visitor { + fn visit_parsed_module(&mut self, _: &ParsedModule) -> bool { + true + } + + fn visit_item(&mut self, _: &Item) -> bool { + true + } + + fn visit_parsed_submodule(&mut self, _: &ParsedSubModule) -> bool { + true + } + + fn visit_noir_function(&mut self, _: &NoirFunction) -> bool { + true + } + + fn visit_noir_trait_impl(&mut self, _: &NoirTraitImpl) -> bool { + true + } + + fn visit_type_impl(&mut self, _: &TypeImpl) -> bool { + true + } + + fn visit_trait_impl_item(&mut self, _: &TraitImplItem) -> bool { + true + } + + fn visit_noir_trait(&mut self, _: &NoirTrait) -> bool { + true + } + + fn visit_trait_item(&mut self, _: &TraitItem) -> bool { + true + } + + fn visit_use_tree(&mut self, _: &UseTree) -> bool { + true + } + + fn visit_noir_struct(&mut self, _: &NoirStruct) {} + + fn visit_noir_type_alias(&mut self, _: &NoirTypeAlias) {} + + fn visit_module_declaration(&mut self, _: &ModuleDeclaration) {} + + fn visit_expression(&mut self, _: &Expression) -> bool { + true + } + + fn visit_literal(&mut self, _: &Literal) -> bool { + true + } + + fn visit_block_expression(&mut self, _: &BlockExpression) -> bool { + true + } + + fn visit_prefix_expression(&mut self, _: &PrefixExpression) -> bool { + true + } + + fn visit_index_expression(&mut self, _: &IndexExpression) -> bool { + true + } + + fn visit_call_expression(&mut self, _: &CallExpression) -> bool { + true + } + + fn visit_method_call_expression(&mut self, _: &MethodCallExpression) -> bool { + true + } + + fn visit_constructor_expression(&mut self, _: &ConstructorExpression) -> bool { + true + } + + fn visit_member_access_expression(&mut self, _: &MemberAccessExpression) -> bool { + true + } + + fn visit_cast_expression(&mut self, _: &CastExpression) -> bool { + true + } + + fn visit_infix_expression(&mut self, _: &InfixExpression) -> bool { + true + } + + fn visit_if_expression(&mut self, _: &IfExpression) -> bool { + true + } + + fn visit_tuple(&mut self, _: &[Expression]) -> bool { + true + } + + fn visit_parenthesized(&mut self, _: &Expression) -> bool { + true + } + + fn visit_unquote(&mut self, _: &Expression) -> bool { + true + } + + fn visit_comptime_expression(&mut self, _: &BlockExpression) -> bool { + true + } + + fn visit_unsafe(&mut self, _: &BlockExpression) -> bool { + true + } + + fn visit_variable(&mut self, _: &Path) {} + + fn visit_lambda(&mut self, _: &Lambda) -> bool { + true + } + + fn visit_array_literal(&mut self, _: &ArrayLiteral) -> bool { + true + } + + fn visit_statement(&mut self, _: &Statement) -> bool { + true + } + + fn visit_global(&mut self, _: &LetStatement) -> bool { + true + } + + fn visit_let_statement(&mut self, _: &LetStatement) -> bool { + true + } + + fn visit_constrain_statement(&mut self, _: &ConstrainStatement) -> bool { + true + } + + fn visit_assign_statement(&mut self, _: &AssignStatement) -> bool { + true + } + + fn visit_for_loop_statement(&mut self, _: &ForLoopStatement) -> bool { + true + } + + fn visit_comptime_statement(&mut self, _: &Statement) -> bool { + true + } + + fn visit_lvalue(&mut self, _: &LValue) -> bool { + true + } + + fn visit_lvalue_ident(&mut self, _: &Ident) {} + + fn visit_for_range(&mut self, _: &ForRange) -> bool { + true + } + + fn visit_as_trait_path(&mut self, _: &AsTraitPath) {} +} + +pub(crate) trait Acceptor { + fn accept(&self, visitor: &mut impl Visitor); +} + +pub(crate) trait ChildrenAcceptor { + fn accept_children(&self, visitor: &mut impl Visitor); +} + +impl Acceptor for ParsedModule { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_parsed_module(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for ParsedModule { + fn accept_children(&self, visitor: &mut impl Visitor) { + for item in &self.items { + item.accept(visitor); + } + } +} + +impl Acceptor for Item { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_item(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for Item { + fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.kind { + ItemKind::Submodules(parsed_sub_module) => { + parsed_sub_module.accept(visitor); + } + ItemKind::Function(noir_function) => noir_function.accept(visitor), + ItemKind::TraitImpl(noir_trait_impl) => { + noir_trait_impl.accept(visitor); + } + ItemKind::Impl(type_impl) => type_impl.accept(visitor), + ItemKind::Global(let_statement) => { + if visitor.visit_global(let_statement) { + let_statement.accept(visitor) + } + } + ItemKind::Trait(noir_trait) => noir_trait.accept(visitor), + ItemKind::Import(use_tree) => use_tree.accept(visitor), + ItemKind::TypeAlias(noir_type_alias) => noir_type_alias.accept(visitor), + ItemKind::Struct(noir_struct) => noir_struct.accept(visitor), + ItemKind::ModuleDecl(module_declaration) => module_declaration.accept(visitor), + } + } +} + +impl Acceptor for ParsedSubModule { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_parsed_submodule(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for ParsedSubModule { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.contents.accept(visitor); + } +} + +impl Acceptor for NoirFunction { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_noir_function(self) { + self.accept_children(visitor); + } + } +} +impl ChildrenAcceptor for NoirFunction { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.def.body.accept(visitor); + } +} + +impl Acceptor for NoirTraitImpl { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_noir_trait_impl(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for NoirTraitImpl { + fn accept_children(&self, visitor: &mut impl Visitor) { + for item in &self.items { + item.accept(visitor); + } + } +} + +impl Acceptor for TraitImplItem { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_trait_impl_item(self) { + self.accept_children(visitor); + } + } +} +impl ChildrenAcceptor for TraitImplItem { + fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + TraitImplItem::Function(noir_function) => noir_function.accept(visitor), + TraitImplItem::Constant(..) => (), + TraitImplItem::Type { .. } => (), + } + } +} + +impl Acceptor for TypeImpl { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_type_impl(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for TypeImpl { + fn accept_children(&self, visitor: &mut impl Visitor) { + for (method, _span) in &self.methods { + method.accept(visitor); + } + } +} + +impl Acceptor for NoirTrait { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_noir_trait(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for NoirTrait { + fn accept_children(&self, visitor: &mut impl Visitor) { + for item in &self.items { + item.accept(visitor); + } + } +} + +impl Acceptor for TraitItem { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_trait_item(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for TraitItem { + fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + TraitItem::Function { + name: _, + generics: _, + parameters: _, + return_type: _, + where_clause: _, + body, + } => { + if let Some(body) = body { + body.accept(visitor); + } + } + TraitItem::Constant { name: _, typ: _, default_value } => { + if let Some(default_value) = default_value { + default_value.accept(visitor); + } + } + TraitItem::Type { name: _ } => (), + } + } +} + +impl Acceptor for UseTree { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_use_tree(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for UseTree { + fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.kind { + UseTreeKind::Path(..) => (), + UseTreeKind::List(use_trees) => { + for use_tree in use_trees { + use_tree.accept(visitor); + } + } + } + } +} + +impl Acceptor for NoirStruct { + fn accept(&self, visitor: &mut impl Visitor) { + visitor.visit_noir_struct(self); + } +} + +impl Acceptor for NoirTypeAlias { + fn accept(&self, visitor: &mut impl Visitor) { + visitor.visit_noir_type_alias(self); + } +} + +impl Acceptor for ModuleDeclaration { + fn accept(&self, visitor: &mut impl Visitor) { + visitor.visit_module_declaration(self); + } +} + +impl Acceptor for Expression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for Expression { + fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.kind { + ExpressionKind::Literal(literal) => literal.accept(visitor), + ExpressionKind::Block(block_expression) => { + block_expression.accept(visitor); + } + ExpressionKind::Prefix(prefix_expression) => { + prefix_expression.accept(visitor); + } + ExpressionKind::Index(index_expression) => { + index_expression.accept(visitor); + } + ExpressionKind::Call(call_expression) => { + call_expression.accept(visitor); + } + ExpressionKind::MethodCall(method_call_expression) => { + method_call_expression.accept(visitor); + } + ExpressionKind::Constructor(constructor_expression) => { + constructor_expression.accept(visitor); + } + ExpressionKind::MemberAccess(member_access_expression) => { + member_access_expression.accept(visitor); + } + ExpressionKind::Cast(cast_expression) => { + cast_expression.accept(visitor); + } + ExpressionKind::Infix(infix_expression) => { + infix_expression.accept(visitor); + } + ExpressionKind::If(if_expression) => { + if_expression.accept(visitor); + } + ExpressionKind::Tuple(expressions) => { + if visitor.visit_tuple(expressions) { + visit_expressions(expressions, visitor); + } + } + ExpressionKind::Lambda(lambda) => lambda.accept(visitor), + ExpressionKind::Parenthesized(expression) => { + if visitor.visit_parenthesized(expression) { + expression.accept(visitor); + } + } + ExpressionKind::Unquote(expression) => { + if visitor.visit_unquote(expression) { + expression.accept(visitor); + } + } + ExpressionKind::Comptime(block_expression, _) => { + if visitor.visit_comptime_expression(block_expression) { + block_expression.accept(visitor); + } + } + ExpressionKind::Unsafe(block_expression, _) => { + if visitor.visit_unsafe(block_expression) { + block_expression.accept(visitor); + } + } + ExpressionKind::Variable(path) => { + visitor.visit_variable(path); + } + ExpressionKind::AsTraitPath(as_trait_path) => { + as_trait_path.accept(visitor); + } + ExpressionKind::Quote(_) | ExpressionKind::Resolved(_) | ExpressionKind::Error => (), + } + } +} + +impl Acceptor for Literal { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_literal(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for Literal { + fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + Literal::Array(array_literal) | Literal::Slice(array_literal) => { + array_literal.accept(visitor); + } + Literal::Bool(_) + | Literal::Integer(_, _) + | Literal::Str(_) + | Literal::RawStr(_, _) + | Literal::FmtStr(_) + | Literal::Unit => (), + } + } +} + +impl Acceptor for BlockExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_block_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for BlockExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + for statement in &self.statements { + statement.accept(visitor); + } + } +} + +impl Acceptor for PrefixExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_prefix_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for PrefixExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.rhs.accept(visitor); + } +} + +impl Acceptor for IndexExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_index_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for IndexExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.collection.accept(visitor); + self.index.accept(visitor); + } +} + +impl Acceptor for CallExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_call_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for CallExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.func.accept(visitor); + visit_expressions(&self.arguments, visitor); + } +} + +impl Acceptor for MethodCallExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_method_call_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for MethodCallExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.object.accept(visitor); + visit_expressions(&self.arguments, visitor); + } +} + +impl Acceptor for ConstructorExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_constructor_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for ConstructorExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + for (_field_name, expression) in &self.fields { + expression.accept(visitor); + } + } +} + +impl Acceptor for MemberAccessExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_member_access_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for MemberAccessExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.lhs.accept(visitor); + } +} + +impl Acceptor for CastExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_cast_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for CastExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.lhs.accept(visitor); + } +} + +impl Acceptor for InfixExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_infix_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for InfixExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.lhs.accept(visitor); + self.rhs.accept(visitor); + } +} + +impl Acceptor for IfExpression { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_if_expression(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for IfExpression { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.condition.accept(visitor); + self.consequence.accept(visitor); + if let Some(alternative) = &self.alternative { + alternative.accept(visitor); + } + } +} + +impl Acceptor for Lambda { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_lambda(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for Lambda { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.body.accept(visitor); + } +} + +impl Acceptor for ArrayLiteral { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_array_literal(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for ArrayLiteral { + fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + ArrayLiteral::Standard(expressions) => visit_expressions(expressions, visitor), + ArrayLiteral::Repeated { repeated_element, length } => { + repeated_element.accept(visitor); + length.accept(visitor); + } + } + } +} + +impl Acceptor for Statement { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_statement(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for Statement { + fn accept_children(&self, visitor: &mut impl Visitor) { + match &self.kind { + StatementKind::Let(let_statement) => { + let_statement.accept(visitor); + } + StatementKind::Constrain(constrain_statement) => { + constrain_statement.accept(visitor); + } + StatementKind::Expression(expression) => { + expression.accept(visitor); + } + StatementKind::Assign(assign_statement) => { + assign_statement.accept(visitor); + } + StatementKind::For(for_loop_statement) => { + for_loop_statement.accept(visitor); + } + StatementKind::Comptime(statement) => { + if visitor.visit_comptime_statement(statement) { + statement.accept(visitor); + } + } + StatementKind::Semi(expression) => { + expression.accept(visitor); + } + StatementKind::Break | StatementKind::Continue | StatementKind::Error => (), + } + } +} + +impl Acceptor for LetStatement { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_let_statement(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for LetStatement { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.expression.accept(visitor); + } +} + +impl Acceptor for ConstrainStatement { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_constrain_statement(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for ConstrainStatement { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.0.accept(visitor); + + if let Some(exp) = &self.1 { + exp.accept(visitor); + } + } +} + +impl Acceptor for AssignStatement { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_assign_statement(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for AssignStatement { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.lvalue.accept(visitor); + self.expression.accept(visitor); + } +} + +impl Acceptor for ForLoopStatement { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_for_loop_statement(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for ForLoopStatement { + fn accept_children(&self, visitor: &mut impl Visitor) { + self.range.accept(visitor); + self.block.accept(visitor); + } +} + +impl Acceptor for LValue { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_lvalue(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for LValue { + fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + LValue::Ident(ident) => visitor.visit_lvalue_ident(ident), + LValue::MemberAccess { object, field_name: _, span: _ } => object.accept(visitor), + LValue::Index { array, index, span: _ } => { + array.accept(visitor); + index.accept(visitor); + } + LValue::Dereference(lvalue, _) => lvalue.accept(visitor), + } + } +} + +impl Acceptor for ForRange { + fn accept(&self, visitor: &mut impl Visitor) { + if visitor.visit_for_range(self) { + self.accept_children(visitor); + } + } +} + +impl ChildrenAcceptor for ForRange { + fn accept_children(&self, visitor: &mut impl Visitor) { + match self { + ForRange::Range(start, end) => { + start.accept(visitor); + end.accept(visitor); + } + ForRange::Array(expression) => expression.accept(visitor), + } + } +} + +impl Acceptor for AsTraitPath { + fn accept(&self, visitor: &mut impl Visitor) { + visitor.visit_as_trait_path(self); + } +} + +fn visit_expressions(expressions: &[Expression], visitor: &mut impl Visitor) { + for expression in expressions { + expression.accept(visitor); + } +}