diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index ac4da2892fb..8ce2e1a41c0 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -461,6 +461,18 @@ pub struct PathSegment { pub span: Span, } +impl PathSegment { + /// Returns the span where turbofish happen. For example: + /// + /// foo:: + /// ~^^^^ + /// + /// Returns an empty span at the end of `foo` if there's no turbofish. + pub fn turbofish_span(&self) -> Span { + Span::from(self.ident.span().end()..self.span.end()) + } +} + impl From for PathSegment { fn from(ident: Ident) -> PathSegment { let span = ident.span(); diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 7116ee0ac10..295297cc738 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -429,23 +429,14 @@ impl<'context> Elaborator<'context> { } }; - let struct_generics = if let Some(turbofish_generics) = &last_segment.generics { - if turbofish_generics.len() == struct_generics.len() { - let struct_type = r#type.borrow(); - self.resolve_turbofish_generics(&struct_type.generics, turbofish_generics.clone()) - } else { - self.push_err(TypeCheckError::GenericCountMismatch { - item: format!("struct {}", last_segment.ident), - expected: struct_generics.len(), - found: turbofish_generics.len(), - span: Span::from(last_segment.ident.span().end()..last_segment.span.end()), - }); + let turbofish_span = last_segment.turbofish_span(); - struct_generics - } - } else { - struct_generics - }; + let struct_generics = self.resolve_struct_turbofish_generics( + &r#type.borrow(), + struct_generics, + last_segment.generics, + turbofish_span, + ); let struct_type = r#type.clone(); let generics = struct_generics.clone(); diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs index 7aab8d1a24c..ade5420bce4 100644 --- a/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -157,8 +157,12 @@ impl<'context> Elaborator<'context> { mutable: Option, new_definitions: &mut Vec, ) -> HirPattern { - let name_span = name.last_ident().span(); - let is_self_type = name.last_ident().is_self_type_name(); + let exclude_last_segment = true; + self.check_unsupported_turbofish_usage(&name, exclude_last_segment); + + let last_segment = name.last_segment(); + let name_span = last_segment.ident.span(); + let is_self_type = last_segment.ident.is_self_type_name(); let error_identifier = |this: &mut Self| { // Must create a name here to return a HirPattern::Identifier. Allowing @@ -178,6 +182,15 @@ impl<'context> Elaborator<'context> { } }; + let turbofish_span = last_segment.turbofish_span(); + + let generics = self.resolve_struct_turbofish_generics( + &struct_type.borrow(), + generics, + last_segment.generics, + turbofish_span, + ); + let actual_type = Type::Struct(struct_type.clone(), generics); let location = Location::new(span, self.file); @@ -426,6 +439,30 @@ impl<'context> Elaborator<'context> { }) } + pub(super) fn resolve_struct_turbofish_generics( + &mut self, + struct_type: &StructType, + generics: Vec, + unresolved_turbofish: Option>, + span: Span, + ) -> Vec { + let Some(turbofish_generics) = unresolved_turbofish else { + return generics; + }; + + if turbofish_generics.len() != generics.len() { + self.push_err(TypeCheckError::GenericCountMismatch { + item: format!("struct {}", struct_type.name), + expected: generics.len(), + found: turbofish_generics.len(), + span, + }); + return generics; + } + + self.resolve_turbofish_generics(&struct_type.generics, turbofish_generics) + } + pub(super) fn resolve_turbofish_generics( &mut self, generics: &[ResolvedGeneric], diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 430967d8a51..ada6a3494a5 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1623,8 +1623,7 @@ impl<'context> Elaborator<'context> { } if segment.generics.is_some() { - // From "foo::", create a span for just "::" - let span = Span::from(segment.ident.span().end()..segment.span.end()); + let span = segment.turbofish_span(); self.push_err(TypeCheckError::UnsupportedTurbofishUsage { span }); } } diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 7c9656e3ec0..a7c62048283 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -557,7 +557,7 @@ fn pattern() -> impl NoirParser { .separated_by(just(Token::Comma)) .delimited_by(just(Token::LeftBrace), just(Token::RightBrace)); - let struct_pattern = path() + let struct_pattern = path(super::parse_type()) .then(struct_pattern_fields) .map_with_span(|(typename, fields), span| Pattern::Struct(typename, fields, span)); @@ -1128,7 +1128,7 @@ fn constructor(expr_parser: impl ExprParser) -> impl NoirParser .allow_trailing() .delimited_by(just(Token::LeftBrace), just(Token::RightBrace)); - path().then(args).map(ExpressionKind::constructor) + path(super::parse_type()).then(args).map(ExpressionKind::constructor) } fn constructor_field

(expr_parser: P) -> impl NoirParser<(Ident, Expression)> diff --git a/compiler/noirc_frontend/src/parser/parser/path.rs b/compiler/noirc_frontend/src/parser/parser/path.rs index 5565c392d59..140650af1a2 100644 --- a/compiler/noirc_frontend/src/parser/parser/path.rs +++ b/compiler/noirc_frontend/src/parser/parser/path.rs @@ -1,4 +1,4 @@ -use crate::ast::{Path, PathKind, PathSegment}; +use crate::ast::{Path, PathKind, PathSegment, UnresolvedType}; use crate::parser::NoirParser; use crate::token::{Keyword, Token}; @@ -8,8 +8,10 @@ use chumsky::prelude::*; use super::keyword; use super::primitives::{path_segment, path_segment_no_turbofish}; -pub(super) fn path() -> impl NoirParser { - path_inner(path_segment()) +pub(super) fn path<'a>( + type_parser: impl NoirParser + 'a, +) -> impl NoirParser + 'a { + path_inner(path_segment(type_parser)) } pub(super) fn path_no_turbofish() -> impl NoirParser { @@ -40,13 +42,16 @@ fn empty_path() -> impl NoirParser { } pub(super) fn maybe_empty_path() -> impl NoirParser { - path().or(empty_path()) + path_no_turbofish().or(empty_path()) } #[cfg(test)] mod test { use super::*; - use crate::parser::parser::test_helpers::{parse_all_failing, parse_with}; + use crate::parser::{ + parse_type, + parser::test_helpers::{parse_all_failing, parse_with}, + }; #[test] fn parse_path() { @@ -59,13 +64,13 @@ mod test { ]; for (src, expected_segments) in cases { - let path: Path = parse_with(path(), src).unwrap(); + let path: Path = parse_with(path(parse_type()), src).unwrap(); for (segment, expected) in path.segments.into_iter().zip(expected_segments) { assert_eq!(segment.ident.0.contents, expected); } } - parse_all_failing(path(), vec!["std::", "::std", "std::hash::", "foo::1"]); + parse_all_failing(path(parse_type()), vec!["std::", "::std", "std::hash::", "foo::1"]); } #[test] @@ -78,12 +83,12 @@ mod test { ]; for (src, expected_path_kind) in cases { - let path = parse_with(path(), src).unwrap(); + let path = parse_with(path(parse_type()), src).unwrap(); assert_eq!(path.kind, expected_path_kind); } parse_all_failing( - path(), + path(parse_type()), vec!["crate", "crate::std::crate", "foo::bar::crate", "foo::dep"], ); } diff --git a/compiler/noirc_frontend/src/parser/parser/primitives.rs b/compiler/noirc_frontend/src/parser/parser/primitives.rs index eb8d67b751a..25f693bf504 100644 --- a/compiler/noirc_frontend/src/parser/parser/primitives.rs +++ b/compiler/noirc_frontend/src/parser/parser/primitives.rs @@ -33,10 +33,14 @@ pub(super) fn token_kind(token_kind: TokenKind) -> impl NoirParser { }) } -pub(super) fn path_segment() -> impl NoirParser { - ident() - .then(turbofish(super::parse_type())) - .map_with_span(|(ident, generics), span| PathSegment { ident, generics, span }) +pub(super) fn path_segment<'a>( + type_parser: impl NoirParser + 'a, +) -> impl NoirParser + 'a { + ident().then(turbofish(type_parser)).map_with_span(|(ident, generics), span| PathSegment { + ident, + generics, + span, + }) } pub(super) fn path_segment_no_turbofish() -> impl NoirParser { @@ -96,7 +100,7 @@ pub(super) fn turbofish<'a>( } pub(super) fn variable() -> impl NoirParser { - path().map(ExpressionKind::Variable) + path(super::parse_type()).map(ExpressionKind::Variable) } pub(super) fn variable_no_turbofish() -> impl NoirParser { diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index f2b83a48022..c23870bbb43 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -2610,3 +2610,74 @@ fn turbofish_in_middle_of_variable_unsupported_yet() { CompilationError::TypeError(TypeCheckError::UnsupportedTurbofishUsage { .. }), )); } + +#[test] +fn turbofish_in_struct_pattern() { + let src = r#" + struct Foo { + x: T + } + + fn main() { + let value: Field = 0; + let Foo:: { x } = Foo { x: value }; + let _ = x; + } + "#; + assert_no_errors(src); +} + +#[test] +fn turbofish_in_struct_pattern_errors_if_type_mismatch() { + let src = r#" + struct Foo { + x: T + } + + fn main() { + let value: Field = 0; + let Foo:: { x } = Foo { x: value }; + let _ = x; + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::TypeMismatchWithSource { .. }) = &errors[0].0 + else { + panic!("Expected a type mismatch error, got {:?}", errors[0].0); + }; +} + +#[test] +fn turbofish_in_struct_pattern_generic_count_mismatch() { + let src = r#" + struct Foo { + x: T + } + + fn main() { + let value = 0; + let Foo:: { x } = Foo { x: value }; + let _ = x; + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::GenericCountMismatch { + item, + expected, + found, + .. + }) = &errors[0].0 + else { + panic!("Expected a generic count mismatch error, got {:?}", errors[0].0); + }; + + assert_eq!(item, "struct Foo"); + assert_eq!(*expected, 1); + assert_eq!(*found, 2); +}