diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index d8be6c72aec..b85ab483860 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -17,7 +17,7 @@ use noirc_frontend::{ ast::{ AsTraitPath, AttributeTarget, BlockExpression, CallExpression, ConstructorExpression, Expression, ExpressionKind, ForLoopStatement, GenericTypeArgs, Ident, IfExpression, - ItemVisibility, Lambda, LetStatement, MemberAccessExpression, MethodCallExpression, + ItemVisibility, LValue, Lambda, LetStatement, MemberAccessExpression, MethodCallExpression, NoirFunction, NoirStruct, NoirTraitImpl, Path, PathKind, Pattern, Statement, TraitImplItemKind, TypeImpl, UnresolvedGeneric, UnresolvedGenerics, UnresolvedType, UnresolvedTypeData, UseTree, UseTreeKind, Visitor, @@ -29,7 +29,7 @@ use noirc_frontend::{ node_interner::ReferenceId, parser::{Item, ItemKind, ParsedSubModule}, token::CustomAttribute, - ParsedModule, StructType, Type, + ParsedModule, StructType, Type, TypeBinding, }; use sort_text::underscore_sort_text; @@ -551,6 +551,7 @@ impl<'a> NodeFinder<'a> { function_completion_kind: FunctionCompletionKind, self_prefix: bool, ) { + let typ = &typ; match typ { Type::Struct(struct_type, generics) => { self.complete_struct_fields(&struct_type.borrow(), generics, prefix, self_prefix); @@ -575,6 +576,16 @@ impl<'a> NodeFinder<'a> { Type::Tuple(types) => { self.complete_tuple_fields(types, self_prefix); } + Type::TypeVariable(var, _) | Type::NamedGeneric(var, _, _) => { + if let TypeBinding::Bound(typ) = &*var.borrow() { + self.complete_type_fields_and_methods( + typ, + prefix, + function_completion_kind, + self_prefix, + ); + } + } Type::FieldElement | Type::Array(_, _) | Type::Slice(_) @@ -583,9 +594,7 @@ impl<'a> NodeFinder<'a> { | Type::String(_) | Type::FmtString(_, _) | Type::Unit - | Type::TypeVariable(_, _) | Type::TraitAsType(_, _, _) - | Type::NamedGeneric(_, _, _) | Type::Function(..) | Type::Forall(_, _) | Type::Constant(_) @@ -932,7 +941,8 @@ impl<'a> NodeFinder<'a> { if let Some(ReferenceId::Local(definition_id)) = self.interner.find_referenced(location) { - self.self_type = Some(self.interner.definition_type(definition_id)); + self.self_type = + Some(self.interner.definition_type(definition_id).follow_bindings()); } } } @@ -941,6 +951,32 @@ impl<'a> NodeFinder<'a> { } } + fn get_lvalue_type(&self, lvalue: &LValue) -> Option { + match lvalue { + LValue::Ident(ident) => { + let location = Location::new(ident.span(), self.file); + if let Some(ReferenceId::Local(definition_id)) = + self.interner.find_referenced(location) + { + let typ = self.interner.definition_type(definition_id); + Some(typ) + } else { + None + } + } + LValue::MemberAccess { object, field_name, .. } => { + let typ = self.get_lvalue_type(object)?; + get_field_type(&typ, &field_name.0.contents) + } + LValue::Index { array, .. } => { + let typ = self.get_lvalue_type(array)?; + get_array_element_type(typ) + } + LValue::Dereference(lvalue, ..) => self.get_lvalue_type(lvalue), + LValue::Interned(..) => None, + } + } + fn includes_span(&self, span: Span) -> bool { span.start() as usize <= self.byte_index && self.byte_index <= span.end() as usize } @@ -1153,7 +1189,6 @@ impl<'a> Visitor for NodeFinder<'a> { if after_dot && call_expression.func.span.end() as usize == self.byte_index - 1 { let location = Location::new(call_expression.func.span, self.file); if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); let prefix = ""; let self_prefix = false; self.complete_type_fields_and_methods( @@ -1184,7 +1219,6 @@ impl<'a> Visitor for NodeFinder<'a> { if self.includes_span(method_call_expression.method_name.span()) { let location = Location::new(method_call_expression.object.span, self.file); if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); let prefix = method_call_expression.method_name.to_string(); let offset = self.byte_index - method_call_expression.method_name.span().start() as usize; @@ -1258,6 +1292,7 @@ impl<'a> Visitor for NodeFinder<'a> { } fn visit_lvalue_ident(&mut self, ident: &Ident) { + // If we have `foo.>|<` we suggest `foo`'s type fields and methods if self.byte == Some(b'.') && ident.span().end() as usize == self.byte_index - 1 { let location = Location::new(ident.span(), self.file); if let Some(ReferenceId::Local(definition_id)) = self.interner.find_referenced(location) @@ -1275,6 +1310,72 @@ impl<'a> Visitor for NodeFinder<'a> { } } + fn visit_lvalue_member_access( + &mut self, + object: &LValue, + field_name: &Ident, + span: Span, + ) -> bool { + // If we have `foo.bar.>|<` we solve the type of `foo`, get the field `bar`, + // then suggest methods of the resulting type. + if self.byte == Some(b'.') && span.end() as usize == self.byte_index - 1 { + if let Some(typ) = self.get_lvalue_type(object) { + if let Some(typ) = get_field_type(&typ, &field_name.0.contents) { + let prefix = ""; + let self_prefix = false; + self.complete_type_fields_and_methods( + &typ, + prefix, + FunctionCompletionKind::NameAndParameters, + self_prefix, + ); + } + } + + return false; + } + true + } + + fn visit_lvalue_index(&mut self, array: &LValue, _index: &Expression, span: Span) -> bool { + // If we have `foo[index].>|<` we solve the type of `foo`, then get the array/slice element type, + // then suggest methods of that type. + if self.byte == Some(b'.') && span.end() as usize == self.byte_index - 1 { + if let Some(typ) = self.get_lvalue_type(array) { + if let Some(typ) = get_array_element_type(typ) { + let prefix = ""; + let self_prefix = false; + self.complete_type_fields_and_methods( + &typ, + prefix, + FunctionCompletionKind::NameAndParameters, + self_prefix, + ); + } + } + return false; + } + true + } + + fn visit_lvalue_dereference(&mut self, lvalue: &LValue, span: Span) -> bool { + if self.byte == Some(b'.') && span.end() as usize == self.byte_index - 1 { + if let Some(typ) = self.get_lvalue_type(lvalue) { + let prefix = ""; + let self_prefix = false; + self.complete_type_fields_and_methods( + &typ, + prefix, + FunctionCompletionKind::NameAndParameters, + self_prefix, + ); + } + return false; + } + + true + } + fn visit_variable(&mut self, path: &Path, _: Span) -> bool { self.find_in_path(path, RequestedItems::AnyItems); false @@ -1294,7 +1395,6 @@ impl<'a> Visitor for NodeFinder<'a> { { let location = Location::new(expression.span, self.file); if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); let prefix = ""; let self_prefix = false; self.complete_type_fields_and_methods( @@ -1364,7 +1464,6 @@ impl<'a> Visitor for NodeFinder<'a> { // Assuming member_access_expression is of the form `foo.bar`, we are right after `bar` let location = Location::new(member_access_expression.lhs.span, self.file); if let Some(typ) = self.interner.type_at_location(location) { - let typ = typ.follow_bindings(); let prefix = ident.to_string().to_case(Case::Snake); let self_prefix = false; self.complete_type_fields_and_methods( @@ -1443,6 +1542,48 @@ impl<'a> Visitor for NodeFinder<'a> { } } +fn get_field_type(typ: &Type, name: &str) -> Option { + match typ { + Type::Struct(struct_type, generics) => { + Some(struct_type.borrow().get_field(name, generics)?.0) + } + Type::Tuple(types) => { + if let Ok(index) = name.parse::() { + types.get(index as usize).cloned() + } else { + None + } + } + Type::Alias(alias_type, generics) => Some(alias_type.borrow().get_type(generics)), + Type::TypeVariable(var, _) | Type::NamedGeneric(var, _, _) => { + if let TypeBinding::Bound(typ) = &*var.borrow() { + get_field_type(typ, name) + } else { + None + } + } + _ => None, + } +} + +fn get_array_element_type(typ: Type) -> Option { + match typ { + Type::Array(_, typ) | Type::Slice(typ) => Some(*typ), + Type::Alias(alias_type, generics) => { + let typ = alias_type.borrow().get_type(&generics); + get_array_element_type(typ) + } + Type::TypeVariable(var, _) | Type::NamedGeneric(var, _, _) => { + if let TypeBinding::Bound(typ) = &*var.borrow() { + get_array_element_type(typ.clone()) + } else { + None + } + } + _ => None, + } +} + /// Returns true if name matches a prefix written in code. /// `prefix` must already be in snake case. /// This method splits both name and prefix by underscore, diff --git a/tooling/lsp/src/requests/completion/tests.rs b/tooling/lsp/src/requests/completion/tests.rs index f019d9980ab..0d1ac6d78ca 100644 --- a/tooling/lsp/src/requests/completion/tests.rs +++ b/tooling/lsp/src/requests/completion/tests.rs @@ -2015,4 +2015,108 @@ mod completion_tests { ) .await; } + + #[test] + async fn test_suggests_when_assignment_follows_in_chain_1() { + let src = r#" + struct Foo { + bar: Bar + } + + struct Bar { + baz: Field + } + + fn f(foo: Foo) { + let mut x = 1; + + foo.bar.>|< + + x = 2; + }"#; + + assert_completion(src, vec![field_completion_item("baz", "Field")]).await; + } + + #[test] + async fn test_suggests_when_assignment_follows_in_chain_2() { + let src = r#" + struct Foo { + bar: Bar + } + + struct Bar { + baz: Baz + } + + struct Baz { + qux: Field + } + + fn f(foo: Foo) { + let mut x = 1; + + foo.bar.baz.>|< + + x = 2; + }"#; + + assert_completion(src, vec![field_completion_item("qux", "Field")]).await; + } + + #[test] + async fn test_suggests_when_assignment_follows_in_chain_3() { + let src = r#" + struct Foo { + foo: Field + } + + fn execute() { + let a = Foo { foo: 1 }; + a.>|< + + x = 1; + }"#; + + assert_completion(src, vec![field_completion_item("foo", "Field")]).await; + } + + #[test] + async fn test_suggests_when_assignment_follows_in_chain_4() { + let src = r#" + struct Foo { + bar: Bar + } + + struct Bar { + baz: Field + } + + fn execute() { + let foo = Foo { foo: 1 }; + foo.bar.>|< + + x = 1; + }"#; + + assert_completion(src, vec![field_completion_item("baz", "Field")]).await; + } + + #[test] + async fn test_suggests_when_assignment_follows_in_chain_with_index() { + let src = r#" + struct Foo { + bar: Field + } + + fn f(foos: [Foo; 3]) { + let mut x = 1; + + foos[0].>|< + + x = 2; + }"#; + + assert_completion(src, vec![field_completion_item("bar", "Field")]).await; + } }