Skip to content

Commit

Permalink
chore: convert BlockExpression from tuple to regular struct
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench committed Feb 26, 2024
1 parent 95f4b4d commit 8fd875c
Show file tree
Hide file tree
Showing 12 changed files with 98 additions and 71 deletions.
29 changes: 16 additions & 13 deletions aztec_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ fn generate_storage_implementation(module: &mut SortedModule) -> Result<(), Azte
true,
)),
)],
&BlockExpression(false, vec![storage_constructor_statement]),
&BlockExpression { is_unsafe: false, statements: vec![storage_constructor_statement] },
&[],
&return_type(chained_path!("Self")),
));
Expand Down Expand Up @@ -609,25 +609,25 @@ fn transform_function(
// Add access to the storage struct
if storage_defined {
let storage_def = abstract_storage(&ty.to_lowercase(), false);
func.def.body.1.insert(0, storage_def);
func.def.body.statements.insert(0, storage_def);
}

// Insert the context creation as the first action
let create_context = create_context(&context_name, &func.def.parameters)?;
func.def.body.1.splice(0..0, (create_context).iter().cloned());
func.def.body.statements.splice(0..0, (create_context).iter().cloned());

// Add the inputs to the params
let input = create_inputs(&inputs_name);
func.def.parameters.insert(0, input);

// Abstract return types such that they get added to the kernel's return_values
if let Some(return_values) = abstract_return_values(func) {
func.def.body.1.push(return_values);
func.def.body.statements.push(return_values);
}

// Push the finish method call to the end of the function
let finish_def = create_context_finish();
func.def.body.1.push(finish_def);
func.def.body.statements.push(finish_def);

let return_type = create_return_type(&return_type_name);
func.def.return_type = return_type;
Expand All @@ -651,7 +651,7 @@ fn transform_vm_function(
) -> Result<(), AztecMacroError> {
// Push Avm context creation to the beginning of the function
let create_context = create_avm_context()?;
func.def.body.1.insert(0, create_context);
func.def.body.statements.insert(0, create_context);

// We want the function to be seen as a public function
func.def.is_open = true;
Expand All @@ -671,7 +671,7 @@ fn transform_vm_function(
///
/// This will allow developers to access their contract' storage struct in unconstrained functions
fn transform_unconstrained(func: &mut NoirFunction) {
func.def.body.1.insert(0, abstract_storage("Unconstrained", true));
func.def.body.statements.insert(0, abstract_storage("Unconstrained", true));
}

fn collect_crate_structs(crate_id: &CrateId, context: &HirContext) -> Vec<StructId> {
Expand Down Expand Up @@ -1002,15 +1002,15 @@ fn generate_selector_impl(structure: &NoirStruct) -> TypeImpl {
let mut from_signature_path = selector_path.clone();
from_signature_path.segments.push(ident("from_signature"));

let selector_fun_body = BlockExpression(
false,
vec![make_statement(StatementKind::Expression(call(
let selector_fun_body = BlockExpression {
is_unsafe: false,
statements: vec![make_statement(StatementKind::Expression(call(
variable_path(from_signature_path),
vec![expression(ExpressionKind::Literal(Literal::Str(
SIGNATURE_PLACEHOLDER.to_string(),
)))],
)))],
);
};

// Define `FunctionSelector` return type
let return_type =
Expand Down Expand Up @@ -1235,7 +1235,7 @@ fn create_avm_context() -> Result<Statement, AztecMacroError> {
fn abstract_return_values(func: &NoirFunction) -> Option<Statement> {
let current_return_type = func.return_type().typ;
let len = func.def.body.len();
let last_statement = &func.def.body.1[len - 1];
let last_statement = &func.def.body.statements[len - 1];

// TODO: (length, type) => We can limit the size of the array returned to be limited by kernel size
// Doesn't need done until we have settled on a kernel size
Expand Down Expand Up @@ -1492,7 +1492,10 @@ fn create_loop_over(var: Expression, loop_body: Vec<Statement>) -> Statement {

// What will be looped over
// - `hasher.add({ident}[i] as Field)`
let for_loop_block = expression(ExpressionKind::Block(BlockExpression(false, loop_body)));
let for_loop_block = expression(ExpressionKind::Block(BlockExpression {
is_unsafe: false,
statements: loop_body,
}));

// `for i in 0..{ident}.len()`
make_statement(StatementKind::For(ForLoopStatement {
Expand Down
24 changes: 14 additions & 10 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ impl Expression {
// with tuples without calling them. E.g. `if c { t } else { e }(a, b)` is interpreted
// as a sequence of { if, tuple } rather than a function call. This behavior matches rust.
let kind = if matches!(&lhs.kind, ExpressionKind::If(..)) {
ExpressionKind::Block(BlockExpression(
false,
vec![
ExpressionKind::Block(BlockExpression {
is_unsafe: false,
statements: vec![
Statement { kind: StatementKind::Expression(lhs), span },
Statement {
kind: StatementKind::Expression(Expression::new(
Expand All @@ -192,7 +192,7 @@ impl Expression {
span,
},
],
))
})
} else {
ExpressionKind::Call(Box::new(CallExpression { func: Box::new(lhs), arguments }))
};
Expand Down Expand Up @@ -459,19 +459,22 @@ pub struct IndexExpression {
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct BlockExpression(pub bool, pub Vec<Statement>);
pub struct BlockExpression {
pub is_unsafe: bool,
pub statements: Vec<Statement>,
}

impl BlockExpression {
pub fn pop(&mut self) -> Option<StatementKind> {
self.1.pop().map(|stmt| stmt.kind)
self.statements.pop().map(|stmt| stmt.kind)
}

pub fn len(&self) -> usize {
self.1.len()
self.statements.len()
}

pub fn is_empty(&self) -> bool {
self.1.is_empty()
self.statements.is_empty()
}
}

Expand Down Expand Up @@ -540,8 +543,9 @@ impl Display for Literal {

impl Display for BlockExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "{{")?;
for statement in &self.1 {
let safety = if self.is_unsafe { "unsafe " } else { "" };
writeln!(f, "{safety}{{")?;
for statement in &self.statements {
let statement = statement.kind.to_string();
for line in statement.lines() {
writeln!(f, " {line}")?;
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/ast/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl NoirFunction {
&mut self.def
}
pub fn number_of_statements(&self) -> usize {
self.def.body.1.len()
self.def.body.statements.len()
}
pub fn span(&self) -> Span {
self.def.span
Expand Down
14 changes: 8 additions & 6 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,13 +588,13 @@ impl ForRange {
};

let block_span = block.span;
let new_block = BlockExpression(
false,
vec![
let new_block = BlockExpression {
is_unsafe: false,
statements: vec![
let_elem,
Statement { kind: StatementKind::Expression(block), span: block_span },
],
);
};
let new_block = Expression::new(ExpressionKind::Block(new_block), block_span);
let for_loop = Statement {
kind: StatementKind::For(ForLoopStatement {
Expand All @@ -606,8 +606,10 @@ impl ForRange {
span: for_loop_span,
};

let block =
ExpressionKind::Block(BlockExpression(false, vec![let_array, for_loop]));
let block = ExpressionKind::Block(BlockExpression {
is_unsafe: false,
statements: vec![let_array, for_loop],
});
StatementKind::Expression(Expression::new(block, for_loop_span))
}
}
Expand Down
27 changes: 15 additions & 12 deletions compiler/noirc_frontend/src/debug/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ impl DebugInstrumenter {
})
.collect();

self.walk_scope(&mut func.body.1, func.span);
self.walk_scope(&mut func.body.statements, func.span);

// prepend fn params:
func.body.1 = [set_fn_params, func.body.1.clone()].concat();
func.body.statements = [set_fn_params, func.body.statements.clone()].concat();
}

// Modify a vector of statements in-place, adding instrumentation for sets and drops.
Expand Down Expand Up @@ -212,7 +212,10 @@ impl DebugInstrumenter {
pattern: ast::Pattern::Tuple(vars_pattern, let_stmt.pattern.span()),
r#type: ast::UnresolvedType::unspecified(),
expression: ast::Expression {
kind: ast::ExpressionKind::Block(ast::BlockExpression(true, block_stmts)),
kind: ast::ExpressionKind::Block(ast::BlockExpression {
is_unsafe: true,
statements: block_stmts,
}),
span: let_stmt.expression.span,
},
}),
Expand Down Expand Up @@ -299,14 +302,14 @@ impl DebugInstrumenter {
kind: ast::StatementKind::Assign(ast::AssignStatement {
lvalue: assign_stmt.lvalue.clone(),
expression: ast::Expression {
kind: ast::ExpressionKind::Block(ast::BlockExpression(
true,
vec![
kind: ast::ExpressionKind::Block(ast::BlockExpression {
is_unsafe: false,
statements: vec![
ast::Statement { kind: let_kind, span: expression_span },
new_assign_stmt,
ast::Statement { kind: ret_kind, span: expression_span },
],
)),
}),
span: expression_span,
},
}),
Expand All @@ -316,7 +319,7 @@ impl DebugInstrumenter {

fn walk_expr(&mut self, expr: &mut ast::Expression) {
match &mut expr.kind {
ast::ExpressionKind::Block(ast::BlockExpression(_, ref mut statements)) => {
ast::ExpressionKind::Block(ast::BlockExpression { ref mut statements, .. }) => {
self.scope.push(HashMap::default());
self.walk_scope(statements, expr.span);
}
Expand Down Expand Up @@ -387,17 +390,17 @@ impl DebugInstrumenter {

self.walk_expr(&mut for_stmt.block);
for_stmt.block = ast::Expression {
kind: ast::ExpressionKind::Block(ast::BlockExpression(
false,
vec![
kind: ast::ExpressionKind::Block(ast::BlockExpression {
is_unsafe: false,
statements: vec![
set_stmt,
ast::Statement {
kind: ast::StatementKind::Semi(for_stmt.block.clone()),
span: for_stmt.block.span,
},
drop_stmt,
],
)),
}),
span: for_stmt.span,
};
}
Expand Down
8 changes: 4 additions & 4 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ impl<'a> Resolver<'a> {
typ: typ.clone(),
span: name.span(),
}),
body: BlockExpression(false, Vec::new()),
body: BlockExpression { is_unsafe: false, statements: Vec::new() },
span: name.span(),
where_clause: where_clause.to_vec(),
return_type: return_type.clone(),
Expand Down Expand Up @@ -1926,9 +1926,9 @@ impl<'a> Resolver<'a> {
}

fn resolve_block(&mut self, block_expr: BlockExpression) -> HirExpression {
let statements =
self.in_new_scope(|this| vecmap(block_expr.1, |stmt| this.intern_stmt(stmt.kind)));
HirExpression::Block(HirBlockExpression(block_expr.0, statements))
let statements = self
.in_new_scope(|this| vecmap(block_expr.statements, |stmt| this.intern_stmt(stmt.kind)));
HirExpression::Block(HirBlockExpression { is_unsafe: block_expr.is_unsafe, statements })
}

pub fn intern_block(&mut self, block: BlockExpression) -> ExprId {
Expand Down
23 changes: 15 additions & 8 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl<'interner> TypeChecker<'interner> {
HirExpression::Block(block_expr) => {
let mut block_type = Type::Unit;

let allow_unsafe = allow_unsafe_call || block_expr.0;
let allow_unsafe = allow_unsafe_call || block_expr.is_unsafe;
let statements = block_expr.statements();
for (i, stmt) in statements.iter().enumerate() {
let expr_type = self.check_statement(stmt, allow_unsafe);
Expand Down Expand Up @@ -538,8 +538,9 @@ impl<'interner> TypeChecker<'interner> {
&mut self,
id: &ExprId,
mut index_expr: expr::HirIndexExpression,
allow_unsafe_call: bool,
) -> Type {
let index_type = self.check_expression(&index_expr.index, false);
let index_type = self.check_expression(&index_expr.index, allow_unsafe_call);
let span = self.interner.expr_span(&index_expr.index);

index_type.unify(
Expand All @@ -554,7 +555,7 @@ impl<'interner> TypeChecker<'interner> {

// When writing `a[i]`, if `a : &mut ...` then automatically dereference `a` as many
// times as needed to get the underlying array.
let lhs_type = self.check_expression(&index_expr.collection, false);
let lhs_type = self.check_expression(&index_expr.collection, allow_unsafe_call);
let (new_lhs, lhs_type) = self.insert_auto_dereferences(index_expr.collection, lhs_type);
index_expr.collection = new_lhs;
self.interner.replace_expr(id, HirExpression::Index(index_expr));
Expand Down Expand Up @@ -607,9 +608,14 @@ impl<'interner> TypeChecker<'interner> {
}
}

fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type {
let cond_type = self.check_expression(&if_expr.condition, false);
let then_type = self.check_expression(&if_expr.consequence, false);
fn check_if_expr(
&mut self,
if_expr: &expr::HirIfExpression,
expr_id: &ExprId,
allow_unsafe_call: bool,
) -> Type {
let cond_type = self.check_expression(&if_expr.condition, allow_unsafe_call);
let then_type = self.check_expression(&if_expr.consequence, allow_unsafe_call);

let expr_span = self.interner.expr_span(&if_expr.condition);

Expand All @@ -622,7 +628,7 @@ impl<'interner> TypeChecker<'interner> {
match if_expr.alternative {
None => Type::Unit,
Some(alternative) => {
let else_type = self.check_expression(&alternative, false);
let else_type = self.check_expression(&alternative, allow_unsafe_call);

let expr_span = self.interner.expr_span(expr_id);
self.unify(&then_type, &else_type, || {
Expand Down Expand Up @@ -652,6 +658,7 @@ impl<'interner> TypeChecker<'interner> {
&mut self,
constructor: expr::HirConstructorExpression,
expr_id: &ExprId,
allow_unsafe_call: bool,
) -> Type {
let typ = constructor.r#type;
let generics = constructor.struct_generics;
Expand All @@ -671,7 +678,7 @@ impl<'interner> TypeChecker<'interner> {
// mismatch here as long as we continue typechecking the rest of the program to the best
// of our ability.
if param_name == arg_ident.0.contents {
let arg_type = self.check_expression(&arg, false);
let arg_type = self.check_expression(&arg, allow_unsafe_call);

let span = self.interner.expr_span(expr_id);
self.unify_with_coercions(&arg_type, &param_type, arg, || {
Expand Down
6 changes: 4 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,10 @@ mod test {
expression: expr_id,
};
let stmt_id = interner.push_stmt(HirStatement::Let(let_stmt));
let expr_id =
interner.push_expr(HirExpression::Block(HirBlockExpression(false, vec![stmt_id])));
let expr_id = interner.push_expr(HirExpression::Block(HirBlockExpression {
is_unsafe: false,
statements: vec![stmt_id],
}));
interner.push_expr_location(expr_id, Span::single_char(0), file);

// Create function to enclose the let statement
Expand Down
Loading

0 comments on commit 8fd875c

Please sign in to comment.