diff --git a/crates/noirc_frontend/src/ast/expression.rs b/crates/noirc_frontend/src/ast/expression.rs index 9be6f715a14..9e9ff2f592e 100644 --- a/crates/noirc_frontend/src/ast/expression.rs +++ b/crates/noirc_frontend/src/ast/expression.rs @@ -106,6 +106,12 @@ impl Recoverable for Expression { } } +impl Recoverable for Option { + fn error(span: Span) -> Self { + Some(Expression::new(ExpressionKind::Error, span)) + } +} + #[derive(Debug, Eq, Clone)] pub struct Expression { pub kind: ExpressionKind, diff --git a/crates/noirc_frontend/src/ast/statement.rs b/crates/noirc_frontend/src/ast/statement.rs index d4fabccea70..2792d51c41c 100644 --- a/crates/noirc_frontend/src/ast/statement.rs +++ b/crates/noirc_frontend/src/ast/statement.rs @@ -51,6 +51,8 @@ impl Statement { last_statement_in_block: bool, emit_error: &mut dyn FnMut(ParserError), ) -> Statement { + let missing_semicolon = + ParserError::with_reason(ParserErrorReason::MissingSeparatingSemi, span); match self { Statement::Let(_) | Statement::Constrain(_) @@ -59,10 +61,7 @@ impl Statement { | Statement::Error => { // To match rust, statements always require a semicolon, even at the end of a block if semi.is_none() { - emit_error(ParserError::with_reason( - ParserErrorReason::MissingSeparatingSemi, - span, - )); + emit_error(missing_semicolon); } self } @@ -85,10 +84,7 @@ impl Statement { // for unneeded expressions like { 1 + 2; 3 } (_, Some(_), false) => Statement::Expression(expr), (_, None, false) => { - emit_error(ParserError::with_reason( - ParserErrorReason::MissingSeparatingSemi, - span, - )); + emit_error(missing_semicolon); Statement::Expression(expr) } diff --git a/crates/noirc_frontend/src/parser/errors.rs b/crates/noirc_frontend/src/parser/errors.rs index d4a294482a8..e788893c58d 100644 --- a/crates/noirc_frontend/src/parser/errors.rs +++ b/crates/noirc_frontend/src/parser/errors.rs @@ -21,6 +21,8 @@ pub enum ParserErrorReason { ConstrainDeprecated, #[error("Expression is invalid in an array-length type: '{0}'. Only unsigned integer constants, globals, generics, +, -, *, /, and % may be used in this context.")] InvalidArrayLengthExpression(Expression), + #[error("Early 'return' is unsupported")] + EarlyReturn, } /// Represents a parsing error, or a parsing error in the making. diff --git a/crates/noirc_frontend/src/parser/parser.rs b/crates/noirc_frontend/src/parser/parser.rs index 98b45247567..2044a02c68e 100644 --- a/crates/noirc_frontend/src/parser/parser.rs +++ b/crates/noirc_frontend/src/parser/parser.rs @@ -449,6 +449,7 @@ where assertion(expr_parser.clone()), declaration(expr_parser.clone()), assignment(expr_parser.clone()), + return_statement(expr_parser.clone()), expr_parser.map(Statement::Expression), )) } @@ -714,6 +715,18 @@ fn expression() -> impl ExprParser { .labelled(ParsingRuleLabel::Expression) } +fn return_statement<'a, P>(expr_parser: P) -> impl NoirParser + 'a +where + P: ExprParser + 'a, +{ + ignore_then_commit(keyword(Keyword::Return), expr_parser.or_not()) + .validate(|_, span, emit| { + emit(ParserError::with_reason(ParserErrorReason::EarlyReturn, span)); + Statement::Error + }) + .labelled(ParsingRuleLabel::Statement) +} + // An expression is a single term followed by 0 or more (OP subexpression)* // where OP is an operator at the given precedence level and subexpression // is an expression at the current precedence level plus one. @@ -1599,4 +1612,40 @@ mod test { ); } } + + #[test] + fn return_validation() { + let cases = vec![ + ("{ return 42; }", 1, "{\n Error\n}"), + ("{ return 1; return 2; }", 2, "{\n Error\n Error\n}"), + ( + "{ return 123; let foo = 4 + 3; }", + 1, + "{\n Error\n let foo: unspecified = (4 + 3)\n}", + ), + ("{ return 1 + 2 }", 2, "{\n Error\n}"), + ("{ return; }", 1, "{\n Error\n}"), + ]; + + let show_errors = |v| vecmap(&v, ToString::to_string).join("\n"); + + let results = vecmap(&cases, |&(src, expected_errors, expected_result)| { + let (opt, errors) = parse_recover(block(expression()), src); + let actual = opt.map(|ast| ast.to_string()); + let actual = if let Some(s) = &actual { s.to_string() } else { "(none)".to_string() }; + + let result = + ((errors.len(), actual.clone()), (expected_errors, expected_result.to_string())); + if result.0 != result.1 { + let num_errors = errors.len(); + let shown_errors = show_errors(errors); + eprintln!( + "\nExpected {} error(s) and got {}:\n\n{}\n\nFrom input: {}\nExpected AST: {}\nActual AST: {}\n", + expected_errors, num_errors, shown_errors, src, expected_result, actual); + } + result + }); + + assert_eq!(vecmap(&results, |t| t.0.clone()), vecmap(&results, |t| t.1.clone()),); + } }