From 8fd875c93b90d9c002d7fdb7f96856ec889f1521 Mon Sep 17 00:00:00 2001 From: Tom French Date: Mon, 26 Feb 2024 18:32:12 +0000 Subject: [PATCH] chore: convert `BlockExpression` from tuple to regular struct --- aztec_macros/src/lib.rs | 29 ++++++++++--------- compiler/noirc_frontend/src/ast/expression.rs | 24 ++++++++------- compiler/noirc_frontend/src/ast/function.rs | 2 +- compiler/noirc_frontend/src/ast/statement.rs | 14 +++++---- compiler/noirc_frontend/src/debug/mod.rs | 27 +++++++++-------- .../src/hir/resolution/resolver.rs | 8 ++--- .../noirc_frontend/src/hir/type_check/expr.rs | 23 ++++++++++----- .../noirc_frontend/src/hir/type_check/mod.rs | 6 ++-- compiler/noirc_frontend/src/hir_def/expr.rs | 9 ++++-- .../src/monomorphization/mod.rs | 2 +- compiler/noirc_frontend/src/parser/parser.rs | 21 ++++++++------ tooling/nargo_fmt/src/visitor/expr.rs | 4 +-- 12 files changed, 98 insertions(+), 71 deletions(-) diff --git a/aztec_macros/src/lib.rs b/aztec_macros/src/lib.rs index ac80bdd3587..79b1eabb310 100644 --- a/aztec_macros/src/lib.rs +++ b/aztec_macros/src/lib.rs @@ -574,7 +574,7 @@ fn generate_storage_implementation(module: &mut SortedModule) -> Result<(), Azte true, )), )], - &BlockExpression(false, vec![storage_constructor_statement]), + &BlockExpression { is_unsafe: false, statements: vec![storage_constructor_statement] }, &[], &return_type(chained_path!("Self")), )); @@ -609,12 +609,12 @@ fn transform_function( // Add access to the storage struct if storage_defined { let storage_def = abstract_storage(&ty.to_lowercase(), false); - func.def.body.1.insert(0, storage_def); + func.def.body.statements.insert(0, storage_def); } // Insert the context creation as the first action let create_context = create_context(&context_name, &func.def.parameters)?; - func.def.body.1.splice(0..0, (create_context).iter().cloned()); + func.def.body.statements.splice(0..0, (create_context).iter().cloned()); // Add the inputs to the params let input = create_inputs(&inputs_name); @@ -622,12 +622,12 @@ fn transform_function( // Abstract return types such that they get added to the kernel's return_values if let Some(return_values) = abstract_return_values(func) { - func.def.body.1.push(return_values); + func.def.body.statements.push(return_values); } // Push the finish method call to the end of the function let finish_def = create_context_finish(); - func.def.body.1.push(finish_def); + func.def.body.statements.push(finish_def); let return_type = create_return_type(&return_type_name); func.def.return_type = return_type; @@ -651,7 +651,7 @@ fn transform_vm_function( ) -> Result<(), AztecMacroError> { // Push Avm context creation to the beginning of the function let create_context = create_avm_context()?; - func.def.body.1.insert(0, create_context); + func.def.body.statements.insert(0, create_context); // We want the function to be seen as a public function func.def.is_open = true; @@ -671,7 +671,7 @@ fn transform_vm_function( /// /// This will allow developers to access their contract' storage struct in unconstrained functions fn transform_unconstrained(func: &mut NoirFunction) { - func.def.body.1.insert(0, abstract_storage("Unconstrained", true)); + func.def.body.statements.insert(0, abstract_storage("Unconstrained", true)); } fn collect_crate_structs(crate_id: &CrateId, context: &HirContext) -> Vec { @@ -1002,15 +1002,15 @@ fn generate_selector_impl(structure: &NoirStruct) -> TypeImpl { let mut from_signature_path = selector_path.clone(); from_signature_path.segments.push(ident("from_signature")); - let selector_fun_body = BlockExpression( - false, - vec![make_statement(StatementKind::Expression(call( + let selector_fun_body = BlockExpression { + is_unsafe: false, + statements: vec![make_statement(StatementKind::Expression(call( variable_path(from_signature_path), vec![expression(ExpressionKind::Literal(Literal::Str( SIGNATURE_PLACEHOLDER.to_string(), )))], )))], - ); + }; // Define `FunctionSelector` return type let return_type = @@ -1235,7 +1235,7 @@ fn create_avm_context() -> Result { fn abstract_return_values(func: &NoirFunction) -> Option { let current_return_type = func.return_type().typ; let len = func.def.body.len(); - let last_statement = &func.def.body.1[len - 1]; + let last_statement = &func.def.body.statements[len - 1]; // TODO: (length, type) => We can limit the size of the array returned to be limited by kernel size // Doesn't need done until we have settled on a kernel size @@ -1492,7 +1492,10 @@ fn create_loop_over(var: Expression, loop_body: Vec) -> Statement { // What will be looped over // - `hasher.add({ident}[i] as Field)` - let for_loop_block = expression(ExpressionKind::Block(BlockExpression(false, loop_body))); + let for_loop_block = expression(ExpressionKind::Block(BlockExpression { + is_unsafe: false, + statements: loop_body, + })); // `for i in 0..{ident}.len()` make_statement(StatementKind::For(ForLoopStatement { diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index e3e4268aad5..dd4eca24784 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -180,9 +180,9 @@ impl Expression { // with tuples without calling them. E.g. `if c { t } else { e }(a, b)` is interpreted // as a sequence of { if, tuple } rather than a function call. This behavior matches rust. let kind = if matches!(&lhs.kind, ExpressionKind::If(..)) { - ExpressionKind::Block(BlockExpression( - false, - vec![ + ExpressionKind::Block(BlockExpression { + is_unsafe: false, + statements: vec![ Statement { kind: StatementKind::Expression(lhs), span }, Statement { kind: StatementKind::Expression(Expression::new( @@ -192,7 +192,7 @@ impl Expression { span, }, ], - )) + }) } else { ExpressionKind::Call(Box::new(CallExpression { func: Box::new(lhs), arguments })) }; @@ -459,19 +459,22 @@ pub struct IndexExpression { } #[derive(Debug, PartialEq, Eq, Clone)] -pub struct BlockExpression(pub bool, pub Vec); +pub struct BlockExpression { + pub is_unsafe: bool, + pub statements: Vec, +} impl BlockExpression { pub fn pop(&mut self) -> Option { - self.1.pop().map(|stmt| stmt.kind) + self.statements.pop().map(|stmt| stmt.kind) } pub fn len(&self) -> usize { - self.1.len() + self.statements.len() } pub fn is_empty(&self) -> bool { - self.1.is_empty() + self.statements.is_empty() } } @@ -540,8 +543,9 @@ impl Display for Literal { impl Display for BlockExpression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "{{")?; - for statement in &self.1 { + let safety = if self.is_unsafe { "unsafe " } else { "" }; + writeln!(f, "{safety}{{")?; + for statement in &self.statements { let statement = statement.kind.to_string(); for line in statement.lines() { writeln!(f, " {line}")?; diff --git a/compiler/noirc_frontend/src/ast/function.rs b/compiler/noirc_frontend/src/ast/function.rs index b4c74f9ba83..3e8b78c1312 100644 --- a/compiler/noirc_frontend/src/ast/function.rs +++ b/compiler/noirc_frontend/src/ast/function.rs @@ -83,7 +83,7 @@ impl NoirFunction { &mut self.def } pub fn number_of_statements(&self) -> usize { - self.def.body.1.len() + self.def.body.statements.len() } pub fn span(&self) -> Span { self.def.span diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index b9350c92bbe..654554fe0e8 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -588,13 +588,13 @@ impl ForRange { }; let block_span = block.span; - let new_block = BlockExpression( - false, - vec![ + let new_block = BlockExpression { + is_unsafe: false, + statements: vec![ let_elem, Statement { kind: StatementKind::Expression(block), span: block_span }, ], - ); + }; let new_block = Expression::new(ExpressionKind::Block(new_block), block_span); let for_loop = Statement { kind: StatementKind::For(ForLoopStatement { @@ -606,8 +606,10 @@ impl ForRange { span: for_loop_span, }; - let block = - ExpressionKind::Block(BlockExpression(false, vec![let_array, for_loop])); + let block = ExpressionKind::Block(BlockExpression { + is_unsafe: false, + statements: vec![let_array, for_loop], + }); StatementKind::Expression(Expression::new(block, for_loop_span)) } } diff --git a/compiler/noirc_frontend/src/debug/mod.rs b/compiler/noirc_frontend/src/debug/mod.rs index fcf577d1ac6..983ec1ccc81 100644 --- a/compiler/noirc_frontend/src/debug/mod.rs +++ b/compiler/noirc_frontend/src/debug/mod.rs @@ -93,10 +93,10 @@ impl DebugInstrumenter { }) .collect(); - self.walk_scope(&mut func.body.1, func.span); + self.walk_scope(&mut func.body.statements, func.span); // prepend fn params: - func.body.1 = [set_fn_params, func.body.1.clone()].concat(); + func.body.statements = [set_fn_params, func.body.statements.clone()].concat(); } // Modify a vector of statements in-place, adding instrumentation for sets and drops. @@ -212,7 +212,10 @@ impl DebugInstrumenter { pattern: ast::Pattern::Tuple(vars_pattern, let_stmt.pattern.span()), r#type: ast::UnresolvedType::unspecified(), expression: ast::Expression { - kind: ast::ExpressionKind::Block(ast::BlockExpression(true, block_stmts)), + kind: ast::ExpressionKind::Block(ast::BlockExpression { + is_unsafe: true, + statements: block_stmts, + }), span: let_stmt.expression.span, }, }), @@ -299,14 +302,14 @@ impl DebugInstrumenter { kind: ast::StatementKind::Assign(ast::AssignStatement { lvalue: assign_stmt.lvalue.clone(), expression: ast::Expression { - kind: ast::ExpressionKind::Block(ast::BlockExpression( - true, - vec![ + kind: ast::ExpressionKind::Block(ast::BlockExpression { + is_unsafe: false, + statements: vec![ ast::Statement { kind: let_kind, span: expression_span }, new_assign_stmt, ast::Statement { kind: ret_kind, span: expression_span }, ], - )), + }), span: expression_span, }, }), @@ -316,7 +319,7 @@ impl DebugInstrumenter { fn walk_expr(&mut self, expr: &mut ast::Expression) { match &mut expr.kind { - ast::ExpressionKind::Block(ast::BlockExpression(_, ref mut statements)) => { + ast::ExpressionKind::Block(ast::BlockExpression { ref mut statements, .. }) => { self.scope.push(HashMap::default()); self.walk_scope(statements, expr.span); } @@ -387,9 +390,9 @@ impl DebugInstrumenter { self.walk_expr(&mut for_stmt.block); for_stmt.block = ast::Expression { - kind: ast::ExpressionKind::Block(ast::BlockExpression( - false, - vec![ + kind: ast::ExpressionKind::Block(ast::BlockExpression { + is_unsafe: false, + statements: vec![ set_stmt, ast::Statement { kind: ast::StatementKind::Semi(for_stmt.block.clone()), @@ -397,7 +400,7 @@ impl DebugInstrumenter { }, drop_stmt, ], - )), + }), span: for_stmt.span, }; } diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 0597fcdd9e0..ad469f8146d 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -244,7 +244,7 @@ impl<'a> Resolver<'a> { typ: typ.clone(), span: name.span(), }), - body: BlockExpression(false, Vec::new()), + body: BlockExpression { is_unsafe: false, statements: Vec::new() }, span: name.span(), where_clause: where_clause.to_vec(), return_type: return_type.clone(), @@ -1926,9 +1926,9 @@ impl<'a> Resolver<'a> { } fn resolve_block(&mut self, block_expr: BlockExpression) -> HirExpression { - let statements = - self.in_new_scope(|this| vecmap(block_expr.1, |stmt| this.intern_stmt(stmt.kind))); - HirExpression::Block(HirBlockExpression(block_expr.0, statements)) + let statements = self + .in_new_scope(|this| vecmap(block_expr.statements, |stmt| this.intern_stmt(stmt.kind))); + HirExpression::Block(HirBlockExpression { is_unsafe: block_expr.is_unsafe, statements }) } pub fn intern_block(&mut self, block: BlockExpression) -> ExprId { diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index fa242be5444..6ca60582f0d 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -256,7 +256,7 @@ impl<'interner> TypeChecker<'interner> { HirExpression::Block(block_expr) => { let mut block_type = Type::Unit; - let allow_unsafe = allow_unsafe_call || block_expr.0; + let allow_unsafe = allow_unsafe_call || block_expr.is_unsafe; let statements = block_expr.statements(); for (i, stmt) in statements.iter().enumerate() { let expr_type = self.check_statement(stmt, allow_unsafe); @@ -538,8 +538,9 @@ impl<'interner> TypeChecker<'interner> { &mut self, id: &ExprId, mut index_expr: expr::HirIndexExpression, + allow_unsafe_call: bool, ) -> Type { - let index_type = self.check_expression(&index_expr.index, false); + let index_type = self.check_expression(&index_expr.index, allow_unsafe_call); let span = self.interner.expr_span(&index_expr.index); index_type.unify( @@ -554,7 +555,7 @@ impl<'interner> TypeChecker<'interner> { // When writing `a[i]`, if `a : &mut ...` then automatically dereference `a` as many // times as needed to get the underlying array. - let lhs_type = self.check_expression(&index_expr.collection, false); + let lhs_type = self.check_expression(&index_expr.collection, allow_unsafe_call); let (new_lhs, lhs_type) = self.insert_auto_dereferences(index_expr.collection, lhs_type); index_expr.collection = new_lhs; self.interner.replace_expr(id, HirExpression::Index(index_expr)); @@ -607,9 +608,14 @@ impl<'interner> TypeChecker<'interner> { } } - fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type { - let cond_type = self.check_expression(&if_expr.condition, false); - let then_type = self.check_expression(&if_expr.consequence, false); + fn check_if_expr( + &mut self, + if_expr: &expr::HirIfExpression, + expr_id: &ExprId, + allow_unsafe_call: bool, + ) -> Type { + let cond_type = self.check_expression(&if_expr.condition, allow_unsafe_call); + let then_type = self.check_expression(&if_expr.consequence, allow_unsafe_call); let expr_span = self.interner.expr_span(&if_expr.condition); @@ -622,7 +628,7 @@ impl<'interner> TypeChecker<'interner> { match if_expr.alternative { None => Type::Unit, Some(alternative) => { - let else_type = self.check_expression(&alternative, false); + let else_type = self.check_expression(&alternative, allow_unsafe_call); let expr_span = self.interner.expr_span(expr_id); self.unify(&then_type, &else_type, || { @@ -652,6 +658,7 @@ impl<'interner> TypeChecker<'interner> { &mut self, constructor: expr::HirConstructorExpression, expr_id: &ExprId, + allow_unsafe_call: bool, ) -> Type { let typ = constructor.r#type; let generics = constructor.struct_generics; @@ -671,7 +678,7 @@ impl<'interner> TypeChecker<'interner> { // mismatch here as long as we continue typechecking the rest of the program to the best // of our ability. if param_name == arg_ident.0.contents { - let arg_type = self.check_expression(&arg, false); + let arg_type = self.check_expression(&arg, allow_unsafe_call); let span = self.interner.expr_span(expr_id); self.unify_with_coercions(&arg_type, ¶m_type, arg, || { diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index ad34bf217e4..57e2e47ed8c 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -297,8 +297,10 @@ mod test { expression: expr_id, }; let stmt_id = interner.push_stmt(HirStatement::Let(let_stmt)); - let expr_id = - interner.push_expr(HirExpression::Block(HirBlockExpression(false, vec![stmt_id]))); + let expr_id = interner.push_expr(HirExpression::Block(HirBlockExpression { + is_unsafe: false, + statements: vec![stmt_id], + })); interner.push_expr_location(expr_id, Span::single_char(0), file); // Create function to enclose the let statement diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index c1603770240..f094e1058dd 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -36,7 +36,7 @@ pub enum HirExpression { impl HirExpression { /// Returns an empty block expression pub const fn empty_block() -> HirExpression { - HirExpression::Block(HirBlockExpression(false, vec![])) + HirExpression::Block(HirBlockExpression { is_unsafe: false, statements: vec![] }) } } @@ -247,11 +247,14 @@ pub struct HirIndexExpression { } #[derive(Debug, Clone)] -pub struct HirBlockExpression(pub bool, pub Vec); +pub struct HirBlockExpression { + pub is_unsafe: bool, + pub statements: Vec, +} impl HirBlockExpression { pub fn statements(&self) -> &[StmtId] { - &self.1 + &self.statements } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 87bf3d47c1a..0ac629af18e 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -373,7 +373,7 @@ impl<'interner> Monomorphizer<'interner> { } }, HirExpression::Literal(HirLiteral::Unit) => ast::Expression::Block(vec![]), - HirExpression::Block(block) => self.block(block.1), + HirExpression::Block(block) => self.block(block.statements), HirExpression::Prefix(prefix) => { let location = self.interner.expr_location(&expr); diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 6e4c3c2b1a9..38146a67e7c 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -354,7 +354,7 @@ fn block<'a>( |span| vec![Statement { kind: StatementKind::Error, span }], )), ) - .map(|(is_unsafe, statements)| BlockExpression(is_unsafe, statements)) + .map(|(is_unsafe, statements)| BlockExpression { is_unsafe, statements }) } fn check_statements_require_semicolon( @@ -1000,10 +1000,13 @@ where // Wrap the inner `if` expression in a block expression. // i.e. rewrite the sugared form `if cond1 {} else if cond2 {}` as `if cond1 {} else { if cond2 {} }`. let if_expression = Expression::new(kind, span); - let desugared_else = BlockExpression( - false, - vec![Statement { kind: StatementKind::Expression(if_expression), span }], - ); + let desugared_else = BlockExpression { + is_unsafe: false, + statements: vec![Statement { + kind: StatementKind::Expression(if_expression), + span, + }], + }; Expression::new(ExpressionKind::Block(desugared_else), span) })); @@ -1295,13 +1298,13 @@ mod test { // Regression for #1310: this should be parsed as a block and not a function call let res = parse_with(block(fresh_statement()), "{ if true { 1 } else { 2 } (3, 4) }").unwrap(); - match unwrap_expr(&res.1.last().unwrap().kind) { + match unwrap_expr(&res.statements.last().unwrap().kind) { // The `if` followed by a tuple is currently creates a block around both in case // there was none to start with, so there is an extra block here. ExpressionKind::Block(block) => { - assert_eq!(block.1.len(), 2); - assert!(matches!(unwrap_expr(&block.1[0].kind), ExpressionKind::If(_))); - assert!(matches!(unwrap_expr(&block.1[1].kind), ExpressionKind::Tuple(_))); + assert_eq!(block.statements.len(), 2); + assert!(matches!(unwrap_expr(&block.statements[0].kind), ExpressionKind::If(_))); + assert!(matches!(unwrap_expr(&block.statements[1].kind), ExpressionKind::Tuple(_))); } _ => unreachable!(), } diff --git a/tooling/nargo_fmt/src/visitor/expr.rs b/tooling/nargo_fmt/src/visitor/expr.rs index 3e432326019..f9836adda18 100644 --- a/tooling/nargo_fmt/src/visitor/expr.rs +++ b/tooling/nargo_fmt/src/visitor/expr.rs @@ -119,11 +119,11 @@ impl FmtVisitor<'_> { self.last_position = block_span.start() + 1; // `{` self.push_str("{"); - self.trim_spaces_after_opening_brace(&block.1); + self.trim_spaces_after_opening_brace(&block.statements); self.indent.block_indent(self.config); - self.visit_stmts(block.1); + self.visit_stmts(block.statements); let span = (self.last_position..block_span.end() - 1).into(); self.close_block(span);