diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 639d4d8f763..9fe78f40c59 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -413,7 +413,14 @@ pub enum LValue { } #[derive(Debug, PartialEq, Eq, Clone)] -pub struct ConstrainStatement(pub Expression, pub Option); +pub struct ConstrainStatement(pub Expression, pub Option, pub ConstrainKind); + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ConstrainKind { + Assert, + AssertEq, + Constrain, +} #[derive(Debug, PartialEq, Eq, Clone)] pub enum Pattern { diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index f58f7315e8c..cee824a51c9 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -37,11 +37,12 @@ use crate::lexer::Lexer; use crate::parser::{force, ignore_then_commit, statement_recovery}; use crate::token::{Attribute, Attributes, Keyword, SecondaryAttribute, Token, TokenKind}; use crate::{ - BinaryOp, BinaryOpKind, BlockExpression, ConstrainStatement, Distinctness, FunctionDefinition, - FunctionReturnType, FunctionVisibility, Ident, IfExpression, InfixExpression, LValue, Lambda, - Literal, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, PathKind, - Pattern, Recoverable, Statement, TraitBound, TraitImplItem, TraitItem, TypeImpl, UnaryOp, - UnresolvedTraitConstraint, UnresolvedTypeExpression, UseTree, UseTreeKind, Visibility, + BinaryOp, BinaryOpKind, BlockExpression, ConstrainKind, ConstrainStatement, Distinctness, + FunctionDefinition, FunctionReturnType, FunctionVisibility, Ident, IfExpression, + InfixExpression, LValue, Lambda, Literal, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, + NoirTypeAlias, Path, PathKind, Pattern, Recoverable, Statement, TraitBound, TraitImplItem, + TraitItem, TypeImpl, UnaryOp, UnresolvedTraitConstraint, UnresolvedTypeExpression, UseTree, + UseTreeKind, Visibility, }; use chumsky::prelude::*; @@ -821,7 +822,7 @@ where keyword(Keyword::Constrain).labelled(ParsingRuleLabel::Statement), expr_parser, ) - .map(|expr| StatementKind::Constrain(ConstrainStatement(expr, None))) + .map(|expr| StatementKind::Constrain(ConstrainStatement(expr, None, ConstrainKind::Constrain))) .validate(|expr, span, emit| { emit(ParserError::with_reason(ParserErrorReason::ConstrainDeprecated, span)); expr @@ -849,7 +850,11 @@ where } } - StatementKind::Constrain(ConstrainStatement(condition, message_str)) + StatementKind::Constrain(ConstrainStatement( + condition, + message_str, + ConstrainKind::Assert, + )) }) } @@ -880,7 +885,11 @@ where emit(ParserError::with_reason(ParserErrorReason::AssertMessageNotString, span)); } } - StatementKind::Constrain(ConstrainStatement(predicate, message_str)) + StatementKind::Constrain(ConstrainStatement( + predicate, + message_str, + ConstrainKind::AssertEq, + )) }) } @@ -2014,7 +2023,7 @@ mod test { match parse_with(assertion(expression()), "assert(x == y, \"assertion message\")").unwrap() { - StatementKind::Constrain(ConstrainStatement(_, message)) => { + StatementKind::Constrain(ConstrainStatement(_, message, _)) => { assert_eq!(message, Some("assertion message".to_owned())); } _ => unreachable!(), @@ -2038,7 +2047,7 @@ mod test { match parse_with(assertion_eq(expression()), "assert_eq(x, y, \"assertion message\")") .unwrap() { - StatementKind::Constrain(ConstrainStatement(_, message)) => { + StatementKind::Constrain(ConstrainStatement(_, message, _)) => { assert_eq!(message, Some("assertion message".to_owned())); } _ => unreachable!(), diff --git a/tooling/nargo_fmt/src/utils.rs b/tooling/nargo_fmt/src/utils.rs index 68c6b8bd1e3..5b2cf2f3c47 100644 --- a/tooling/nargo_fmt/src/utils.rs +++ b/tooling/nargo_fmt/src/utils.rs @@ -218,7 +218,7 @@ impl Item for Expression { } fn format(self, visitor: &FmtVisitor) -> String { - visitor.format_subexpr(self) + visitor.format_sub_expr(self) } } @@ -232,7 +232,7 @@ impl Item for (Ident, Expression) { let (name, expr) = self; let name = name.0.contents; - let expr = visitor.format_subexpr(expr); + let expr = visitor.format_sub_expr(expr); if name == expr { name diff --git a/tooling/nargo_fmt/src/visitor/expr.rs b/tooling/nargo_fmt/src/visitor/expr.rs index 66d586888a6..8f855cd6157 100644 --- a/tooling/nargo_fmt/src/visitor/expr.rs +++ b/tooling/nargo_fmt/src/visitor/expr.rs @@ -20,7 +20,7 @@ impl FmtVisitor<'_> { self.last_position = span.end(); } - pub(crate) fn format_subexpr(&self, expression: Expression) -> String { + pub(crate) fn format_sub_expr(&self, expression: Expression) -> String { self.format_expr(expression, ExpressionType::SubExpression) } @@ -49,24 +49,24 @@ impl FmtVisitor<'_> { } }; - format!("{op}{}", self.format_subexpr(prefix.rhs)) + format!("{op}{}", self.format_sub_expr(prefix.rhs)) } ExpressionKind::Cast(cast) => { - format!("{} as {}", self.format_subexpr(cast.lhs), cast.r#type) + format!("{} as {}", self.format_sub_expr(cast.lhs), cast.r#type) } ExpressionKind::Infix(infix) => { format!( "{} {} {}", - self.format_subexpr(infix.lhs), + self.format_sub_expr(infix.lhs), infix.operator.contents.as_string(), - self.format_subexpr(infix.rhs) + self.format_sub_expr(infix.rhs) ) } ExpressionKind::Call(call_expr) => { let args_span = self.span_before(call_expr.func.span.end()..span.end(), Token::LeftParen); - let callee = self.format_subexpr(*call_expr.func); + let callee = self.format_sub_expr(*call_expr.func); let args = format_parens(self.fork(), false, call_expr.arguments, args_span); format!("{callee}{args}") @@ -77,21 +77,21 @@ impl FmtVisitor<'_> { Token::LeftParen, ); - let object = self.format_subexpr(method_call_expr.object); + let object = self.format_sub_expr(method_call_expr.object); let method = method_call_expr.method_name.to_string(); let args = format_parens(self.fork(), false, method_call_expr.arguments, args_span); format!("{object}.{method}{args}") } ExpressionKind::MemberAccess(member_access_expr) => { - let lhs_str = self.format_subexpr(member_access_expr.lhs); + let lhs_str = self.format_sub_expr(member_access_expr.lhs); format!("{}.{}", lhs_str, member_access_expr.rhs) } ExpressionKind::Index(index_expr) => { let index_span = self .span_before(index_expr.collection.span.end()..span.end(), Token::LeftBracket); - let collection = self.format_subexpr(index_expr.collection); + let collection = self.format_sub_expr(index_expr.collection); let index = format_brackets(self.fork(), false, vec![index_expr.index], index_span); format!("{collection}{index}") @@ -104,8 +104,8 @@ impl FmtVisitor<'_> { self.slice(span).to_string() } Literal::Array(ArrayLiteral::Repeated { repeated_element, length }) => { - let repeated = self.format_subexpr(*repeated_element); - let length = self.format_subexpr(*length); + let repeated = self.format_sub_expr(*repeated_element); + let length = self.format_sub_expr(*length); format!("[{repeated}; {length}]") } @@ -139,7 +139,7 @@ impl FmtVisitor<'_> { } if !leading.contains("//") && !trailing.contains("//") { - let sub_expr = self.format_subexpr(*sub_expr); + let sub_expr = self.format_sub_expr(*sub_expr); format!("({leading}{sub_expr}{trailing})") } else { let mut visitor = self.fork(); @@ -148,7 +148,7 @@ impl FmtVisitor<'_> { visitor.indent.block_indent(self.config); let nested_indent = visitor.indent.to_string_with_newline(); - let sub_expr = visitor.format_subexpr(*sub_expr); + let sub_expr = visitor.format_sub_expr(*sub_expr); let mut result = String::new(); result.push('('); @@ -192,13 +192,14 @@ impl FmtVisitor<'_> { self.format_if(*if_expr) } - _ => self.slice(span).to_string(), + ExpressionKind::Variable(_) | ExpressionKind::Lambda(_) => self.slice(span).to_string(), + ExpressionKind::Error => unreachable!(), } } fn format_if(&self, if_expr: IfExpression) -> String { - let condition_str = self.format_subexpr(if_expr.condition); - let consequence_str = self.format_subexpr(if_expr.consequence); + let condition_str = self.format_sub_expr(if_expr.condition); + let consequence_str = self.format_sub_expr(if_expr.consequence); let mut result = format!("if {condition_str} {consequence_str}"); @@ -219,8 +220,8 @@ impl FmtVisitor<'_> { } fn format_if_single_line(&self, if_expr: IfExpression) -> Option { - let condition_str = self.format_subexpr(if_expr.condition); - let consequence_str = self.format_subexpr(extract_simple_expr(if_expr.consequence)?); + let condition_str = self.format_sub_expr(if_expr.condition); + let consequence_str = self.format_sub_expr(extract_simple_expr(if_expr.consequence)?); let if_str = if let Some(alternative) = if_expr.alternative { let alternative_str = if let Some(ExpressionKind::If(_)) = diff --git a/tooling/nargo_fmt/src/visitor/stmt.rs b/tooling/nargo_fmt/src/visitor/stmt.rs index ca28c8a5c06..0a814ebd136 100644 --- a/tooling/nargo_fmt/src/visitor/stmt.rs +++ b/tooling/nargo_fmt/src/visitor/stmt.rs @@ -1,6 +1,6 @@ use std::iter::zip; -use noirc_frontend::{Statement, StatementKind}; +use noirc_frontend::{ConstrainKind, ConstrainStatement, ExpressionKind, Statement, StatementKind}; use super::ExpressionType; @@ -28,8 +28,35 @@ impl super::FmtVisitor<'_> { self.push_rewrite(format!("{let_str} {expr_str};"), span); } + StatementKind::Constrain(ConstrainStatement(expr, message, kind)) => { + let message = + message.map_or(String::new(), |message| format!(", \"{message}\"")); + let constrain = match kind { + ConstrainKind::Assert => { + let assertion = self.format_sub_expr(expr); + + format!("assert({assertion}{message});") + } + ConstrainKind::AssertEq => { + if let ExpressionKind::Infix(infix) = expr.kind { + let lhs = self.format_sub_expr(infix.lhs); + let rhs = self.format_sub_expr(infix.rhs); + + format!("assert_eq({lhs}, {rhs}{message});") + } else { + unreachable!() + } + } + ConstrainKind::Constrain => { + let expr = self.format_sub_expr(expr); + format!("constrain {expr};") + } + }; + + self.push_rewrite(constrain, span); + } + StatementKind::Assign(_) | StatementKind::For(_) => self.format_missing(span.end()), StatementKind::Error => unreachable!(), - _ => self.format_missing(span.end()), } self.last_position = span.end(); diff --git a/tooling/nargo_fmt/tests/expected/add.nr b/tooling/nargo_fmt/tests/expected/add.nr index 6f2892942c1..341ed06f3e6 100644 --- a/tooling/nargo_fmt/tests/expected/add.nr +++ b/tooling/nargo_fmt/tests/expected/add.nr @@ -3,5 +3,5 @@ fn main(mut x: u32, y: u32, z: u32) { assert(x == z); x *= 8; - assert(x>9); + assert(x > 9); } diff --git a/tooling/nargo_fmt/tests/expected/call.nr b/tooling/nargo_fmt/tests/expected/call.nr index 8a104b6d5e0..8f627ed1223 100644 --- a/tooling/nargo_fmt/tests/expected/call.nr +++ b/tooling/nargo_fmt/tests/expected/call.nr @@ -18,4 +18,14 @@ fn foo() { my_function(some_function(10, "arg1", another_function()), another_func(20, some_function(), 30)); outer_function(some_function(), another_function(some_function(), some_value)); + + assert_eq(x, y); + + assert_eq(x, y, "message"); + + assert(x); + + assert(x, "message"); + + assert(x == y); } diff --git a/tooling/nargo_fmt/tests/input/call.nr b/tooling/nargo_fmt/tests/input/call.nr index f76157e83ca..24e61c806cc 100644 --- a/tooling/nargo_fmt/tests/input/call.nr +++ b/tooling/nargo_fmt/tests/input/call.nr @@ -32,4 +32,14 @@ fn foo() { another_function( some_function(), some_value) ); + + assert_eq( x, y ); + + assert_eq( x, y, "message" ); + + assert( x ); + + assert( x, "message" ); + + assert( x == y ); }