From 0e266ebc7328866b0b10554e37c9d9012a7b501c Mon Sep 17 00:00:00 2001 From: jfecher Date: Thu, 5 Oct 2023 04:05:46 -0700 Subject: [PATCH] fix!: Make for loops a statement (#2975) --- compiler/noirc_frontend/src/ast/expression.rs | 20 -- compiler/noirc_frontend/src/ast/statement.rs | 26 ++- .../src/hir/resolution/resolver.rs | 85 ++++---- .../noirc_frontend/src/hir/type_check/expr.rs | 36 +--- .../noirc_frontend/src/hir/type_check/mod.rs | 4 +- .../noirc_frontend/src/hir/type_check/stmt.rs | 38 +++- compiler/noirc_frontend/src/hir_def/expr.rs | 9 - compiler/noirc_frontend/src/hir_def/stmt.rs | 9 + .../src/monomorphization/mod.rs | 43 ++-- compiler/noirc_frontend/src/parser/mod.rs | 21 +- compiler/noirc_frontend/src/parser/parser.rs | 184 +++++++++++------- 11 files changed, 245 insertions(+), 230 deletions(-) diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 3febee7f527..06bbddb9744 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -22,7 +22,6 @@ pub enum ExpressionKind { MemberAccess(Box), Cast(Box), Infix(Box), - For(Box), If(Box), Variable(Path), Tuple(Vec), @@ -181,14 +180,6 @@ impl Expression { } } -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct ForExpression { - pub identifier: Ident, - pub start_range: Expression, - pub end_range: Expression, - pub block: Expression, -} - pub type BinaryOp = Spanned; #[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Debug, Copy, Clone)] @@ -469,7 +460,6 @@ impl Display for ExpressionKind { MethodCall(call) => call.fmt(f), Cast(cast) => cast.fmt(f), Infix(infix) => infix.fmt(f), - For(for_loop) => for_loop.fmt(f), If(if_expr) => if_expr.fmt(f), Variable(path) => path.fmt(f), Constructor(constructor) => constructor.fmt(f), @@ -603,16 +593,6 @@ impl Display for BinaryOpKind { } } -impl Display for ForExpression { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "for {} in {} .. {} {}", - self.identifier, self.start_range, self.end_range, self.block - ) - } -} - impl Display for IfExpression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "if {} {}", self.condition, self.consequence)?; diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 3d9ab1e6ec4..a1834a8b18a 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -23,6 +23,7 @@ pub enum Statement { Constrain(ConstrainStatement), Expression(Expression), Assign(AssignStatement), + For(ForLoopStatement), // This is an expression with a trailing semi-colon Semi(Expression), // This statement is the result of a recovered parse error. @@ -65,13 +66,13 @@ impl Statement { } self } + // A semicolon on a for loop is optional and does nothing + Statement::For(_) => self, Statement::Expression(expr) => { match (&expr.kind, semi, last_statement_in_block) { // Semicolons are optional for these expressions - (ExpressionKind::Block(_), semi, _) - | (ExpressionKind::For(_), semi, _) - | (ExpressionKind::If(_), semi, _) => { + (ExpressionKind::Block(_), semi, _) | (ExpressionKind::If(_), semi, _) => { if semi.is_some() { Statement::Semi(expr) } else { @@ -459,6 +460,14 @@ impl LValue { } } +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct ForLoopStatement { + pub identifier: Ident, + pub start_range: Expression, + pub end_range: Expression, + pub block: Expression, +} + impl Display for Statement { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -466,6 +475,7 @@ impl Display for Statement { Statement::Constrain(constrain) => constrain.fmt(f), Statement::Expression(expression) => expression.fmt(f), Statement::Assign(assign) => assign.fmt(f), + Statement::For(for_loop) => for_loop.fmt(f), Statement::Semi(semi) => write!(f, "{semi};"), Statement::Error => write!(f, "Error"), } @@ -544,3 +554,13 @@ impl Display for Pattern { } } } + +impl Display for ForLoopStatement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "for {} in {} .. {} {}", + self.identifier, self.start_range, self.end_range, self.block + ) + } +} diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index ef66ba5e032..cd7bb97cda2 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -13,9 +13,9 @@ // XXX: Resolver does not check for unused functions use crate::hir_def::expr::{ HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar, - HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent, - HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, - HirMemberAccess, HirMethodCallExpression, HirPrefixExpression, + HirCastExpression, HirConstructorExpression, HirExpression, HirIdent, HirIfExpression, + HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess, + HirMethodCallExpression, HirPrefixExpression, }; use crate::hir_def::traits::{Trait, TraitConstraint}; @@ -26,7 +26,7 @@ use std::rc::Rc; use crate::graph::CrateId; use crate::hir::def_map::{LocalModuleId, ModuleDefId, TryFromModuleDefId, MAIN_FUNCTION}; -use crate::hir_def::stmt::{HirAssignStatement, HirLValue, HirPattern}; +use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern}; use crate::node_interner::{ DefinitionId, DefinitionKind, ExprId, FuncId, NodeInterner, StmtId, StructId, TraitId, }; @@ -957,6 +957,25 @@ impl<'a> Resolver<'a> { let stmt = HirAssignStatement { lvalue: identifier, expression }; HirStatement::Assign(stmt) } + Statement::For(for_loop) => { + let start_range = self.resolve_expression(for_loop.start_range); + let end_range = self.resolve_expression(for_loop.end_range); + let (identifier, block) = (for_loop.identifier, for_loop.block); + + // TODO: For loop variables are currently mutable by default since we haven't + // yet implemented syntax for them to be optionally mutable. + let (identifier, block) = self.in_new_scope(|this| { + let decl = this.add_variable_decl( + identifier, + false, + true, + DefinitionKind::Local(None), + ); + (decl, this.resolve_expression(block)) + }); + + HirStatement::For(HirForStatement { start_range, end_range, block, identifier }) + } Statement::Error => HirStatement::Error, } } @@ -1169,30 +1188,6 @@ impl<'a> Resolver<'a> { lhs: self.resolve_expression(cast_expr.lhs), r#type: self.resolve_type(cast_expr.r#type), }), - ExpressionKind::For(for_expr) => { - let start_range = self.resolve_expression(for_expr.start_range); - let end_range = self.resolve_expression(for_expr.end_range); - let (identifier, block) = (for_expr.identifier, for_expr.block); - - // TODO: For loop variables are currently mutable by default since we haven't - // yet implemented syntax for them to be optionally mutable. - let (identifier, block_id) = self.in_new_scope(|this| { - let decl = this.add_variable_decl( - identifier, - false, - true, - DefinitionKind::Local(None), - ); - (decl, this.resolve_expression(block)) - }); - - HirExpression::For(HirForExpression { - start_range, - end_range, - block: block_id, - identifier, - }) - } ExpressionKind::If(if_expr) => HirExpression::If(HirIfExpression { condition: self.resolve_expression(if_expr.condition), consequence: self.resolve_expression(if_expr.consequence), @@ -1738,7 +1733,7 @@ mod test { let (hir_func, _, _) = resolver.resolve_function(func, id); // Iterate over function statements and apply filtering function - parse_statement_blocks( + find_lambda_captures( hir_func.block(&interner).statements(), &interner, &mut all_captures, @@ -1747,33 +1742,23 @@ mod test { all_captures } - fn parse_statement_blocks( + fn find_lambda_captures( stmts: &[StmtId], interner: &NodeInterner, result: &mut Vec>, ) { - let mut expr: HirExpression; - for stmt_id in stmts.iter() { let hir_stmt = interner.statement(stmt_id); - match hir_stmt { - HirStatement::Expression(expr_id) => { - expr = interner.expression(&expr_id); - } - HirStatement::Let(let_stmt) => { - expr = interner.expression(&let_stmt.expression); - } - HirStatement::Assign(assign_stmt) => { - expr = interner.expression(&assign_stmt.expression); - } - HirStatement::Constrain(constr_stmt) => { - expr = interner.expression(&constr_stmt.0); - } - HirStatement::Semi(semi_expr) => { - expr = interner.expression(&semi_expr); - } + let expr_id = match hir_stmt { + HirStatement::Expression(expr_id) => expr_id, + HirStatement::Let(let_stmt) => let_stmt.expression, + HirStatement::Assign(assign_stmt) => assign_stmt.expression, + HirStatement::Constrain(constr_stmt) => constr_stmt.0, + HirStatement::Semi(semi_expr) => semi_expr, + HirStatement::For(for_loop) => for_loop.block, HirStatement::Error => panic!("Invalid HirStatement!"), - } + }; + let expr = interner.expression(&expr_id); get_lambda_captures(expr, interner, result); // TODO: dyn filter function as parameter } } @@ -1794,7 +1779,7 @@ mod test { // Check for other captures recursively within the lambda body let hir_body_expr = interner.expression(&lambda_expr.body); if let HirExpression::Block(block_expr) = hir_body_expr { - parse_statement_blocks(block_expr.statements(), interner, result); + find_lambda_captures(block_expr.statements(), interner, result); } } } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index c802482d9e0..7eb4dd0184f 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -11,7 +11,7 @@ use crate::{ types::Type, }, node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId}, - Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp, + Signedness, TypeBinding, TypeVariableKind, UnaryOp, }; use super::{errors::TypeCheckError, TypeChecker}; @@ -194,40 +194,6 @@ impl<'interner> TypeChecker<'interner> { let span = self.interner.expr_span(expr_id); self.check_cast(lhs_type, cast_expr.r#type, span) } - HirExpression::For(for_expr) => { - let start_range_type = self.check_expression(&for_expr.start_range); - let end_range_type = self.check_expression(&for_expr.end_range); - - let start_span = self.interner.expr_span(&for_expr.start_range); - let end_span = self.interner.expr_span(&for_expr.end_range); - - // Check that start range and end range have the same types - let range_span = start_span.merge(end_span); - self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch { - expected_typ: start_range_type.to_string(), - expr_typ: end_range_type.to_string(), - expr_span: range_span, - }); - - let fresh_id = self.interner.next_type_variable_id(); - let type_variable = Shared::new(TypeBinding::Unbound(fresh_id)); - let expected_type = - Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField); - - self.unify(&start_range_type, &expected_type, || { - TypeCheckError::TypeCannotBeUsed { - typ: start_range_type.clone(), - place: "for loop", - span: range_span, - } - .add_context("The range of a loop must be known at compile-time") - }); - - self.interner.push_definition_type(for_expr.identifier.id, start_range_type); - - self.check_expression(&for_expr.block); - Type::Unit - } HirExpression::Block(block_expr) => { let mut block_type = Type::Unit; diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index c2afa44c495..4c3ecce3ede 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -325,11 +325,11 @@ mod test { fn basic_for_expr() { let src = r#" fn main(_x : Field) { - let _j = for _i in 0..10 { + for _i in 0..10 { for _k in 0..100 { } - }; + } } "#; diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index 11f106dab10..6993476e249 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -3,10 +3,12 @@ use noirc_errors::{Location, Span}; use crate::hir_def::expr::{HirExpression, HirIdent, HirLiteral}; use crate::hir_def::stmt::{ - HirAssignStatement, HirConstrainStatement, HirLValue, HirLetStatement, HirPattern, HirStatement, + HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement, + HirPattern, HirStatement, }; use crate::hir_def::types::Type; use crate::node_interner::{DefinitionId, ExprId, StmtId}; +use crate::{Shared, TypeBinding, TypeVariableKind}; use super::errors::{Source, TypeCheckError}; use super::TypeChecker; @@ -48,11 +50,45 @@ impl<'interner> TypeChecker<'interner> { HirStatement::Let(let_stmt) => self.check_let_stmt(let_stmt), HirStatement::Constrain(constrain_stmt) => self.check_constrain_stmt(constrain_stmt), HirStatement::Assign(assign_stmt) => self.check_assign_stmt(assign_stmt, stmt_id), + HirStatement::For(for_loop) => self.check_for_loop(for_loop), HirStatement::Error => (), } Type::Unit } + fn check_for_loop(&mut self, for_loop: HirForStatement) { + let start_range_type = self.check_expression(&for_loop.start_range); + let end_range_type = self.check_expression(&for_loop.end_range); + + let start_span = self.interner.expr_span(&for_loop.start_range); + let end_span = self.interner.expr_span(&for_loop.end_range); + + // Check that start range and end range have the same types + let range_span = start_span.merge(end_span); + self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch { + expected_typ: start_range_type.to_string(), + expr_typ: end_range_type.to_string(), + expr_span: range_span, + }); + + let fresh_id = self.interner.next_type_variable_id(); + let type_variable = Shared::new(TypeBinding::Unbound(fresh_id)); + let expected_type = Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField); + + self.unify(&start_range_type, &expected_type, || { + TypeCheckError::TypeCannotBeUsed { + typ: start_range_type.clone(), + place: "for loop", + span: range_span, + } + .add_context("The range of a loop must be known at compile-time") + }); + + self.interner.push_definition_type(for_loop.identifier.id, start_range_type); + + self.check_expression(&for_loop.block); + } + /// Associate a given HirPattern with the given Type, and remember /// this association in the NodeInterner. pub(crate) fn bind_pattern(&mut self, pattern: &HirPattern, typ: Type) { diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 8ec106c8c37..15d4b12d30b 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -26,7 +26,6 @@ pub enum HirExpression { Call(HirCallExpression), MethodCall(HirMethodCallExpression), Cast(HirCastExpression), - For(HirForExpression), If(HirIfExpression), Tuple(Vec), Lambda(HirLambda), @@ -48,14 +47,6 @@ pub struct HirIdent { pub id: DefinitionId, } -#[derive(Debug, Clone)] -pub struct HirForExpression { - pub identifier: HirIdent, - pub start_range: ExprId, - pub end_range: ExprId, - pub block: ExprId, -} - #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct HirBinaryOp { pub kind: BinaryOpKind, diff --git a/compiler/noirc_frontend/src/hir_def/stmt.rs b/compiler/noirc_frontend/src/hir_def/stmt.rs index 76601d893ed..21f9b431b3a 100644 --- a/compiler/noirc_frontend/src/hir_def/stmt.rs +++ b/compiler/noirc_frontend/src/hir_def/stmt.rs @@ -13,6 +13,7 @@ pub enum HirStatement { Let(HirLetStatement), Constrain(HirConstrainStatement), Assign(HirAssignStatement), + For(HirForStatement), Expression(ExprId), Semi(ExprId), Error, @@ -34,6 +35,14 @@ impl HirLetStatement { } } +#[derive(Debug, Clone)] +pub struct HirForStatement { + pub identifier: HirIdent, + pub start_range: ExprId, + pub end_range: ExprId, + pub block: ExprId, +} + /// Corresponds to `lvalue = expression;` in the source code #[derive(Debug, Clone)] pub struct HirAssignStatement { diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 2af0ac433d1..c0a0002b3c1 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -341,28 +341,6 @@ impl<'interner> Monomorphizer<'interner> { location: self.interner.expr_location(&expr), }), - HirExpression::For(for_expr) => { - self.is_range_loop = true; - let start = self.expr(for_expr.start_range); - let end = self.expr(for_expr.end_range); - self.is_range_loop = false; - let index_variable = self.next_local_id(); - self.define_local(for_expr.identifier.id, index_variable); - - let block = Box::new(self.expr(for_expr.block)); - - ast::Expression::For(ast::For { - index_variable, - index_name: self.interner.definition_name(for_expr.identifier.id).to_owned(), - index_type: self.convert_type(&self.interner.id_type(for_expr.start_range)), - start_range: Box::new(start), - end_range: Box::new(end), - start_range_location: self.interner.expr_location(&for_expr.start_range), - end_range_location: self.interner.expr_location(&for_expr.end_range), - block, - }) - } - HirExpression::If(if_expr) => { let cond = self.expr(if_expr.condition); let then = self.expr(if_expr.consequence); @@ -445,6 +423,27 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Constrain(Box::new(expr), location, constrain.2) } HirStatement::Assign(assign) => self.assign(assign), + HirStatement::For(for_loop) => { + self.is_range_loop = true; + let start = self.expr(for_loop.start_range); + let end = self.expr(for_loop.end_range); + self.is_range_loop = false; + let index_variable = self.next_local_id(); + self.define_local(for_loop.identifier.id, index_variable); + + let block = Box::new(self.expr(for_loop.block)); + + ast::Expression::For(ast::For { + index_variable, + index_name: self.interner.definition_name(for_loop.identifier.id).to_owned(), + index_type: self.convert_type(&self.interner.id_type(for_loop.start_range)), + start_range: Box::new(start), + end_range: Box::new(end), + start_range_location: self.interner.expr_location(&for_loop.start_range), + end_range_location: self.interner.expr_location(&for_loop.end_range), + block, + }) + } HirStatement::Expression(expr) => self.expr(expr), HirStatement::Semi(expr) => ast::Expression::Semi(Box::new(self.expr(expr))), HirStatement::Error => unreachable!(), diff --git a/compiler/noirc_frontend/src/parser/mod.rs b/compiler/noirc_frontend/src/parser/mod.rs index efd85861235..8fc882068eb 100644 --- a/compiler/noirc_frontend/src/parser/mod.rs +++ b/compiler/noirc_frontend/src/parser/mod.rs @@ -16,7 +16,7 @@ use std::sync::atomic::{AtomicU32, Ordering}; use crate::token::{Keyword, Token}; use crate::{ast::ImportStatement, Expression, NoirStruct}; use crate::{ - BlockExpression, ExpressionKind, ForExpression, Ident, IndexExpression, LetStatement, + BlockExpression, ExpressionKind, ForLoopStatement, Ident, IndexExpression, LetStatement, MethodCallExpression, NoirFunction, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, PathKind, Pattern, Recoverable, Statement, TypeImpl, UnresolvedType, UseTree, }; @@ -380,15 +380,10 @@ impl ForRange { /// ... /// } /// } - fn into_for(self, identifier: Ident, block: Expression, for_loop_span: Span) -> ExpressionKind { + fn into_for(self, identifier: Ident, block: Expression, for_loop_span: Span) -> Statement { match self { ForRange::Range(start_range, end_range) => { - ExpressionKind::For(Box::new(ForExpression { - identifier, - start_range, - end_range, - block, - })) + Statement::For(ForLoopStatement { identifier, start_range, end_range, block }) } ForRange::Array(array) => { let array_span = array.span; @@ -443,17 +438,15 @@ impl ForRange { let block_span = block.span; let new_block = BlockExpression(vec![let_elem, Statement::Expression(block)]); let new_block = Expression::new(ExpressionKind::Block(new_block), block_span); - let for_loop = ExpressionKind::For(Box::new(ForExpression { + let for_loop = Statement::For(ForLoopStatement { identifier: fresh_identifier, start_range, end_range, block: new_block, - })); + }); - ExpressionKind::Block(BlockExpression(vec![ - let_array, - Statement::Expression(Expression::new(for_loop, for_loop_span)), - ])) + let block = ExpressionKind::Block(BlockExpression(vec![let_array, for_loop])); + Statement::Expression(Expression::new(block, for_loop_span)) } } } diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index f428efd6af7..7723c9f3690 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -166,7 +166,7 @@ fn function_definition(allow_self: bool) -> impl NoirParser { .then(parenthesized(function_parameters(allow_self))) .then(function_return_type()) .then(where_clause()) - .then(spanned(block(expression()))) + .then(spanned(block(fresh_statement()))) .validate(|(((args, ret), where_clause), (body, body_span)), span, emit| { let ((((attributes, modifiers), name), generics), parameters) = args; @@ -414,7 +414,7 @@ fn trait_constant_declaration() -> impl NoirParser { /// trait_function_declaration: 'fn' ident generics '(' declaration_parameters ')' function_return_type fn trait_function_declaration() -> impl NoirParser { let trait_function_body_or_semicolon = - block(expression()).map(Option::from).or(just(Token::Semicolon).map(|_| Option::None)); + block(fresh_statement()).map(Option::from).or(just(Token::Semicolon).map(|_| Option::None)); keyword(Keyword::Fn) .ignore_then(ident()) @@ -638,19 +638,13 @@ fn trait_bound() -> impl NoirParser { }) } -fn block_expr<'a, P>(expr_parser: P) -> impl NoirParser + 'a -where - P: ExprParser + 'a, -{ - block(expr_parser).map(ExpressionKind::Block).map_with_span(Expression::new) +fn block_expr<'a>(statement: impl NoirParser + 'a) -> impl NoirParser + 'a { + block(statement).map(ExpressionKind::Block).map_with_span(Expression::new) } -fn block<'a, P>(expr_parser: P) -> impl NoirParser + 'a -where - P: ExprParser + 'a, -{ +fn block<'a>(statement: impl NoirParser + 'a) -> impl NoirParser + 'a { use Token::*; - statement(expr_parser) + statement .recover_via(statement_recovery()) .then(just(Semicolon).or_not().map_with_span(|s, span| (s, span))) .repeated() @@ -761,19 +755,27 @@ fn ident() -> impl NoirParser { token_kind(TokenKind::Ident).map_with_span(Ident::from_token) } -fn statement<'a, P>(expr_parser: P) -> impl NoirParser + 'a +fn statement<'a, P, P2>(expr_parser: P, expr_no_constructors: P2) -> impl NoirParser + 'a where P: ExprParser + 'a, + P2: ExprParser + 'a, { - choice(( - constrain(expr_parser.clone()), - assertion(expr_parser.clone()), - assertion_eq(expr_parser.clone()), - declaration(expr_parser.clone()), - assignment(expr_parser.clone()), - return_statement(expr_parser.clone()), - expr_parser.map(Statement::Expression), - )) + recursive(|statement| { + choice(( + constrain(expr_parser.clone()), + assertion(expr_parser.clone()), + assertion_eq(expr_parser.clone()), + declaration(expr_parser.clone()), + assignment(expr_parser.clone()), + for_loop(expr_no_constructors, statement), + return_statement(expr_parser.clone()), + expr_parser.map(Statement::Expression), + )) + }) +} + +fn fresh_statement() -> impl NoirParser { + statement(expression(), expression_no_constructors()) } fn constrain<'a, P>(expr_parser: P) -> impl NoirParser + 'a @@ -1099,6 +1101,7 @@ fn type_expression() -> impl NoirParser { Precedence::lowest_type_precedence(), expr, nothing(), + nothing(), true, false, ) @@ -1159,8 +1162,9 @@ fn expression() -> impl ExprParser { recursive(|expr| { expression_with_precedence( Precedence::Lowest, - expr, + expr.clone(), expression_no_constructors(), + statement(expr, expression_no_constructors()), false, true, ) @@ -1170,7 +1174,14 @@ fn expression() -> impl ExprParser { fn expression_no_constructors() -> impl ExprParser { recursive(|expr| { - expression_with_precedence(Precedence::Lowest, expr.clone(), expr, false, false) + expression_with_precedence( + Precedence::Lowest, + expr.clone(), + expr.clone(), + statement(expr.clone(), expr), + false, + false, + ) }) .labelled(ParsingRuleLabel::Expression) } @@ -1190,10 +1201,11 @@ where // An expression is a single term followed by 0 or more (OP subexpression)* // where OP is an operator at the given precedence level and subexpression // is an expression at the current precedence level plus one. -fn expression_with_precedence<'a, P, P2>( +fn expression_with_precedence<'a, P, P2, S>( precedence: Precedence, expr_parser: P, expr_no_constructors: P2, + statement: S, // True if we should only parse the restricted subset of operators valid within type expressions is_type_expression: bool, // True if we should also parse constructors `Foo { field1: value1, ... }` as an expression. @@ -1204,12 +1216,13 @@ fn expression_with_precedence<'a, P, P2>( where P: ExprParser + 'a, P2: ExprParser + 'a, + S: NoirParser + 'a, { if precedence == Precedence::Highest { if is_type_expression { type_expression_term(expr_parser).boxed().labelled(ParsingRuleLabel::Term) } else { - term(expr_parser, expr_no_constructors, allow_constructors) + term(expr_parser, expr_no_constructors, statement, allow_constructors) .boxed() .labelled(ParsingRuleLabel::Term) } @@ -1221,6 +1234,7 @@ where next_precedence, expr_parser, expr_no_constructors, + statement, is_type_expression, allow_constructors, ); @@ -1260,14 +1274,16 @@ fn operator_with_precedence(precedence: Precedence) -> impl NoirParser( +fn term<'a, P, P2, S>( expr_parser: P, expr_no_constructors: P2, + statement: S, allow_constructors: bool, ) -> impl NoirParser + 'a where P: ExprParser + 'a, P2: ExprParser + 'a, + S: NoirParser + 'a, { recursive(move |term_parser| { choice(( @@ -1280,7 +1296,12 @@ where // right-unary operators like a[0] or a.f bind more tightly than left-unary // operators like - or !, so that !a[0] is parsed as !(a[0]). This is a bit // awkward for casts so -a as i32 actually binds as -(a as i32). - .or(atom_or_right_unary(expr_parser, expr_no_constructors, allow_constructors)) + .or(atom_or_right_unary( + expr_parser, + expr_no_constructors, + statement, + allow_constructors, + )) }) } @@ -1295,14 +1316,16 @@ where }) } -fn atom_or_right_unary<'a, P, P2>( +fn atom_or_right_unary<'a, P, P2, S>( expr_parser: P, expr_no_constructors: P2, + statement: S, allow_constructors: bool, ) -> impl NoirParser + 'a where P: ExprParser + 'a, P2: ExprParser + 'a, + S: NoirParser + 'a, { enum UnaryRhs { Call(Vec), @@ -1336,7 +1359,7 @@ where let rhs = choice((call_rhs, array_rhs, cast_rhs, member_rhs)); foldl_with_span( - atom(expr_parser, expr_no_constructors, allow_constructors), + atom(expr_parser, expr_no_constructors, statement, allow_constructors), rhs, |lhs, rhs, span| match rhs { UnaryRhs::Call(args) => Expression::call(lhs, args, span), @@ -1349,25 +1372,21 @@ where ) } -fn if_expr<'a, P1, P2>( - expr_parser: P1, - expr_no_constructors: P2, -) -> impl NoirParser + 'a +fn if_expr<'a, P, S>(expr_no_constructors: P, statement: S) -> impl NoirParser + 'a where - P1: ExprParser + 'a, - P2: ExprParser + 'a, + P: ExprParser + 'a, + S: NoirParser + 'a, { recursive(|if_parser| { - let if_block = block_expr(expr_parser.clone()); + let if_block = block_expr(statement.clone()); // The else block could also be an `else if` block, in which case we must recursively parse it. - let else_block = - block_expr(expr_parser.clone()).or(if_parser.map_with_span(|kind, span| { - // Wrap the inner `if` expression in a block expression. - // i.e. rewrite the sugared form `if cond1 {} else if cond2 {}` as `if cond1 {} else { if cond2 {} }`. - let if_expression = Expression::new(kind, span); - let desugared_else = BlockExpression(vec![Statement::Expression(if_expression)]); - Expression::new(ExpressionKind::Block(desugared_else), span) - })); + let else_block = block_expr(statement).or(if_parser.map_with_span(|kind, span| { + // Wrap the inner `if` expression in a block expression. + // i.e. rewrite the sugared form `if cond1 {} else if cond2 {}` as `if cond1 {} else { if cond2 {} }`. + let if_expression = Expression::new(kind, span); + let desugared_else = BlockExpression(vec![Statement::Expression(if_expression)]); + Expression::new(ExpressionKind::Block(desugared_else), span) + })); keyword(Keyword::If) .ignore_then(expr_no_constructors) @@ -1391,19 +1410,16 @@ fn lambda<'a>( }) } -fn for_expr<'a, P, P2>( - expr_parser: P, - expr_no_constructors: P2, -) -> impl NoirParser + 'a +fn for_loop<'a, P, S>(expr_no_constructors: P, statement: S) -> impl NoirParser + 'a where P: ExprParser + 'a, - P2: ExprParser + 'a, + S: NoirParser + 'a, { keyword(Keyword::For) .ignore_then(ident()) .then_ignore(keyword(Keyword::In)) .then(for_range(expr_no_constructors)) - .then(block_expr(expr_parser)) + .then(block_expr(statement)) .map_with_span(|((identifier, range), block), span| range.into_for(identifier, block, span)) } @@ -1494,18 +1510,19 @@ where /// Atoms are parameterized on whether constructor expressions are allowed or not. /// Certain constructs like `if` and `for` disallow constructor expressions when a /// block may be expected. -fn atom<'a, P, P2>( +fn atom<'a, P, P2, S>( expr_parser: P, expr_no_constructors: P2, + statement: S, allow_constructors: bool, ) -> impl NoirParser + 'a where P: ExprParser + 'a, P2: ExprParser + 'a, + S: NoirParser + 'a, { choice(( - if_expr(expr_parser.clone(), expr_no_constructors.clone()), - for_expr(expr_parser.clone(), expr_no_constructors), + if_expr(expr_no_constructors, statement.clone()), array_expr(expr_parser.clone()), if allow_constructors { constructor(expr_parser.clone()).boxed() @@ -1513,7 +1530,7 @@ where nothing().boxed() }, lambda(expr_parser.clone()), - block(expr_parser.clone()).map(ExpressionKind::Block), + block(statement).map(ExpressionKind::Block), variable(), literal(), )) @@ -1722,11 +1739,21 @@ mod test { #[test] fn parse_cast() { parse_all( - atom_or_right_unary(expression(), expression_no_constructors(), true), + atom_or_right_unary( + expression(), + expression_no_constructors(), + fresh_statement(), + true, + ), vec!["x as u8", "0 as Field", "(x + 3) as [Field; 8]"], ); parse_all_failing( - atom_or_right_unary(expression(), expression_no_constructors(), true), + atom_or_right_unary( + expression(), + expression_no_constructors(), + fresh_statement(), + true, + ), vec!["x as pub u8"], ); } @@ -1740,7 +1767,15 @@ mod test { "baz[bar]", "foo.bar[3] as Field .baz as u32 [7]", ]; - parse_all(atom_or_right_unary(expression(), expression_no_constructors(), true), valid); + parse_all( + atom_or_right_unary( + expression(), + expression_no_constructors(), + fresh_statement(), + true, + ), + valid, + ); } fn expr_to_array(expr: ExpressionKind) -> ArrayLiteral { @@ -1794,10 +1829,11 @@ mod test { #[test] fn parse_block() { - parse_with(block(expression()), "{ [0,1,2,3,4] }").unwrap(); + parse_with(block(fresh_statement()), "{ [0,1,2,3,4] }").unwrap(); // Regression for #1310: this should be parsed as a block and not a function call - let res = parse_with(block(expression()), "{ if true { 1 } else { 2 } (3, 4) }").unwrap(); + let res = + parse_with(block(fresh_statement()), "{ if true { 1 } else { 2 } (3, 4) }").unwrap(); match unwrap_expr(res.0.last().unwrap()) { // The `if` followed by a tuple is currently creates a block around both in case // there was none to start with, so there is an extra block here. @@ -1810,7 +1846,7 @@ mod test { } parse_all_failing( - block(expression()), + block(fresh_statement()), vec![ "[0,1,2,3,4] }", "{ [0,1,2,3,4]", @@ -1967,18 +2003,18 @@ mod test { #[test] fn parse_invalid_pub() { // pub cannot be used to declare a statement - parse_all_failing(statement(expression()), vec!["pub x = y", "pub x : pub Field = y"]); + parse_all_failing(fresh_statement(), vec!["pub x = y", "pub x : pub Field = y"]); } #[test] fn parse_for_loop() { parse_all( - for_expr(expression(), expression_no_constructors()), + for_loop(expression_no_constructors(), fresh_statement()), vec!["for i in x+y..z {}", "for i in 0..100 { foo; bar }"], ); parse_all_failing( - for_expr(expression(), expression_no_constructors()), + for_loop(expression_no_constructors(), fresh_statement()), vec![ "for 1 in x+y..z {}", // Cannot have a literal as the loop identifier "for i in 0...100 {}", // Only '..' is supported, there are no inclusive ranges yet @@ -2073,11 +2109,11 @@ mod test { #[test] fn parse_parenthesized_expression() { parse_all( - atom(expression(), expression_no_constructors(), true), + atom(expression(), expression_no_constructors(), fresh_statement(), true), vec!["(0)", "(x+a)", "({(({{({(nested)})}}))})"], ); parse_all_failing( - atom(expression(), expression_no_constructors(), true), + atom(expression(), expression_no_constructors(), fresh_statement(), true), vec!["(x+a", "((x+a)", "(,)"], ); } @@ -2090,12 +2126,12 @@ mod test { #[test] fn parse_if_expr() { parse_all( - if_expr(expression(), expression_no_constructors()), + if_expr(expression_no_constructors(), fresh_statement()), vec!["if x + a { } else { }", "if x {}", "if x {} else if y {} else {}"], ); parse_all_failing( - if_expr(expression(), expression_no_constructors()), + if_expr(expression_no_constructors(), fresh_statement()), vec!["if (x / a) + 1 {} else", "if foo then 1 else 2", "if true { 1 }else 3"], ); } @@ -2189,11 +2225,11 @@ mod test { #[test] fn parse_unary() { parse_all( - term(expression(), expression_no_constructors(), true), + term(expression(), expression_no_constructors(), fresh_statement(), true), vec!["!hello", "-hello", "--hello", "-!hello", "!-hello"], ); parse_all_failing( - term(expression(), expression_no_constructors(), true), + term(expression(), expression_no_constructors(), fresh_statement(), true), vec!["+hello", "/hello"], ); } @@ -2290,7 +2326,7 @@ mod test { "{ expr1; expr2 }", "{ expr1; expr2; }", ]; - parse_all(block(expression()), cases); + parse_all(block(fresh_statement()), cases); let failing = vec![ // We disallow multiple semicolons after a statement unlike rust where it is a warning @@ -2299,7 +2335,7 @@ mod test { "{ let x = 2 }", "{ expr1 expr2 }", ]; - parse_all_failing(block(expression()), failing); + parse_all_failing(block(fresh_statement()), failing); } #[test] @@ -2324,7 +2360,7 @@ mod test { let show_errors = |v| vecmap(v, ToString::to_string).join("\n"); for (src, expected_errors, expected_result) in cases { - let (opt, errors) = parse_recover(statement(expression()), src); + let (opt, errors) = parse_recover(fresh_statement(), src); let actual = opt.map(|ast| ast.to_string()); let actual = if let Some(s) = &actual { s } else { "(none)" }; @@ -2352,7 +2388,7 @@ mod test { let show_errors = |v| vecmap(&v, ToString::to_string).join("\n"); let results = vecmap(&cases, |&(src, expected_errors, expected_result)| { - let (opt, errors) = parse_recover(block(expression()), src); + let (opt, errors) = parse_recover(block(fresh_statement()), src); let actual = opt.map(|ast| ast.to_string()); let actual = if let Some(s) = &actual { s.to_string() } else { "(none)".to_string() };