Skip to content

Commit

Permalink
fix!: Make for loops a statement (#2975)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher authored Oct 5, 2023
1 parent 982380e commit 0e266eb
Show file tree
Hide file tree
Showing 11 changed files with 245 additions and 230 deletions.
20 changes: 0 additions & 20 deletions compiler/noirc_frontend/src/ast/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ pub enum ExpressionKind {
MemberAccess(Box<MemberAccessExpression>),
Cast(Box<CastExpression>),
Infix(Box<InfixExpression>),
For(Box<ForExpression>),
If(Box<IfExpression>),
Variable(Path),
Tuple(Vec<Expression>),
Expand Down Expand Up @@ -181,14 +180,6 @@ impl Expression {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ForExpression {
pub identifier: Ident,
pub start_range: Expression,
pub end_range: Expression,
pub block: Expression,
}

pub type BinaryOp = Spanned<BinaryOpKind>;

#[derive(PartialEq, PartialOrd, Eq, Ord, Hash, Debug, Copy, Clone)]
Expand Down Expand Up @@ -469,7 +460,6 @@ impl Display for ExpressionKind {
MethodCall(call) => call.fmt(f),
Cast(cast) => cast.fmt(f),
Infix(infix) => infix.fmt(f),
For(for_loop) => for_loop.fmt(f),
If(if_expr) => if_expr.fmt(f),
Variable(path) => path.fmt(f),
Constructor(constructor) => constructor.fmt(f),
Expand Down Expand Up @@ -603,16 +593,6 @@ impl Display for BinaryOpKind {
}
}

impl Display for ForExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"for {} in {} .. {} {}",
self.identifier, self.start_range, self.end_range, self.block
)
}
}

impl Display for IfExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "if {} {}", self.condition, self.consequence)?;
Expand Down
26 changes: 23 additions & 3 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub enum Statement {
Constrain(ConstrainStatement),
Expression(Expression),
Assign(AssignStatement),
For(ForLoopStatement),
// This is an expression with a trailing semi-colon
Semi(Expression),
// This statement is the result of a recovered parse error.
Expand Down Expand Up @@ -65,13 +66,13 @@ impl Statement {
}
self
}
// A semicolon on a for loop is optional and does nothing
Statement::For(_) => self,

Statement::Expression(expr) => {
match (&expr.kind, semi, last_statement_in_block) {
// Semicolons are optional for these expressions
(ExpressionKind::Block(_), semi, _)
| (ExpressionKind::For(_), semi, _)
| (ExpressionKind::If(_), semi, _) => {
(ExpressionKind::Block(_), semi, _) | (ExpressionKind::If(_), semi, _) => {
if semi.is_some() {
Statement::Semi(expr)
} else {
Expand Down Expand Up @@ -459,13 +460,22 @@ impl LValue {
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct ForLoopStatement {
pub identifier: Ident,
pub start_range: Expression,
pub end_range: Expression,
pub block: Expression,
}

impl Display for Statement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Statement::Let(let_statement) => let_statement.fmt(f),
Statement::Constrain(constrain) => constrain.fmt(f),
Statement::Expression(expression) => expression.fmt(f),
Statement::Assign(assign) => assign.fmt(f),
Statement::For(for_loop) => for_loop.fmt(f),
Statement::Semi(semi) => write!(f, "{semi};"),
Statement::Error => write!(f, "Error"),
}
Expand Down Expand Up @@ -544,3 +554,13 @@ impl Display for Pattern {
}
}
}

impl Display for ForLoopStatement {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"for {} in {} .. {} {}",
self.identifier, self.start_range, self.end_range, self.block
)
}
}
85 changes: 35 additions & 50 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
// XXX: Resolver does not check for unused functions
use crate::hir_def::expr::{
HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar,
HirCastExpression, HirConstructorExpression, HirExpression, HirForExpression, HirIdent,
HirIfExpression, HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral,
HirMemberAccess, HirMethodCallExpression, HirPrefixExpression,
HirCastExpression, HirConstructorExpression, HirExpression, HirIdent, HirIfExpression,
HirIndexExpression, HirInfixExpression, HirLambda, HirLiteral, HirMemberAccess,
HirMethodCallExpression, HirPrefixExpression,
};

use crate::hir_def::traits::{Trait, TraitConstraint};
Expand All @@ -26,7 +26,7 @@ use std::rc::Rc;

use crate::graph::CrateId;
use crate::hir::def_map::{LocalModuleId, ModuleDefId, TryFromModuleDefId, MAIN_FUNCTION};
use crate::hir_def::stmt::{HirAssignStatement, HirLValue, HirPattern};
use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern};
use crate::node_interner::{
DefinitionId, DefinitionKind, ExprId, FuncId, NodeInterner, StmtId, StructId, TraitId,
};
Expand Down Expand Up @@ -957,6 +957,25 @@ impl<'a> Resolver<'a> {
let stmt = HirAssignStatement { lvalue: identifier, expression };
HirStatement::Assign(stmt)
}
Statement::For(for_loop) => {
let start_range = self.resolve_expression(for_loop.start_range);
let end_range = self.resolve_expression(for_loop.end_range);
let (identifier, block) = (for_loop.identifier, for_loop.block);

// TODO: For loop variables are currently mutable by default since we haven't
// yet implemented syntax for them to be optionally mutable.
let (identifier, block) = self.in_new_scope(|this| {
let decl = this.add_variable_decl(
identifier,
false,
true,
DefinitionKind::Local(None),
);
(decl, this.resolve_expression(block))
});

HirStatement::For(HirForStatement { start_range, end_range, block, identifier })
}
Statement::Error => HirStatement::Error,
}
}
Expand Down Expand Up @@ -1169,30 +1188,6 @@ impl<'a> Resolver<'a> {
lhs: self.resolve_expression(cast_expr.lhs),
r#type: self.resolve_type(cast_expr.r#type),
}),
ExpressionKind::For(for_expr) => {
let start_range = self.resolve_expression(for_expr.start_range);
let end_range = self.resolve_expression(for_expr.end_range);
let (identifier, block) = (for_expr.identifier, for_expr.block);

// TODO: For loop variables are currently mutable by default since we haven't
// yet implemented syntax for them to be optionally mutable.
let (identifier, block_id) = self.in_new_scope(|this| {
let decl = this.add_variable_decl(
identifier,
false,
true,
DefinitionKind::Local(None),
);
(decl, this.resolve_expression(block))
});

HirExpression::For(HirForExpression {
start_range,
end_range,
block: block_id,
identifier,
})
}
ExpressionKind::If(if_expr) => HirExpression::If(HirIfExpression {
condition: self.resolve_expression(if_expr.condition),
consequence: self.resolve_expression(if_expr.consequence),
Expand Down Expand Up @@ -1738,7 +1733,7 @@ mod test {
let (hir_func, _, _) = resolver.resolve_function(func, id);

// Iterate over function statements and apply filtering function
parse_statement_blocks(
find_lambda_captures(
hir_func.block(&interner).statements(),
&interner,
&mut all_captures,
Expand All @@ -1747,33 +1742,23 @@ mod test {
all_captures
}

fn parse_statement_blocks(
fn find_lambda_captures(
stmts: &[StmtId],
interner: &NodeInterner,
result: &mut Vec<Vec<String>>,
) {
let mut expr: HirExpression;

for stmt_id in stmts.iter() {
let hir_stmt = interner.statement(stmt_id);
match hir_stmt {
HirStatement::Expression(expr_id) => {
expr = interner.expression(&expr_id);
}
HirStatement::Let(let_stmt) => {
expr = interner.expression(&let_stmt.expression);
}
HirStatement::Assign(assign_stmt) => {
expr = interner.expression(&assign_stmt.expression);
}
HirStatement::Constrain(constr_stmt) => {
expr = interner.expression(&constr_stmt.0);
}
HirStatement::Semi(semi_expr) => {
expr = interner.expression(&semi_expr);
}
let expr_id = match hir_stmt {
HirStatement::Expression(expr_id) => expr_id,
HirStatement::Let(let_stmt) => let_stmt.expression,
HirStatement::Assign(assign_stmt) => assign_stmt.expression,
HirStatement::Constrain(constr_stmt) => constr_stmt.0,
HirStatement::Semi(semi_expr) => semi_expr,
HirStatement::For(for_loop) => for_loop.block,
HirStatement::Error => panic!("Invalid HirStatement!"),
}
};
let expr = interner.expression(&expr_id);
get_lambda_captures(expr, interner, result); // TODO: dyn filter function as parameter
}
}
Expand All @@ -1794,7 +1779,7 @@ mod test {
// Check for other captures recursively within the lambda body
let hir_body_expr = interner.expression(&lambda_expr.body);
if let HirExpression::Block(block_expr) = hir_body_expr {
parse_statement_blocks(block_expr.statements(), interner, result);
find_lambda_captures(block_expr.statements(), interner, result);
}
}
}
Expand Down
36 changes: 1 addition & 35 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
types::Type,
},
node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId},
Shared, Signedness, TypeBinding, TypeVariableKind, UnaryOp,
Signedness, TypeBinding, TypeVariableKind, UnaryOp,
};

use super::{errors::TypeCheckError, TypeChecker};
Expand Down Expand Up @@ -194,40 +194,6 @@ impl<'interner> TypeChecker<'interner> {
let span = self.interner.expr_span(expr_id);
self.check_cast(lhs_type, cast_expr.r#type, span)
}
HirExpression::For(for_expr) => {
let start_range_type = self.check_expression(&for_expr.start_range);
let end_range_type = self.check_expression(&for_expr.end_range);

let start_span = self.interner.expr_span(&for_expr.start_range);
let end_span = self.interner.expr_span(&for_expr.end_range);

// Check that start range and end range have the same types
let range_span = start_span.merge(end_span);
self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch {
expected_typ: start_range_type.to_string(),
expr_typ: end_range_type.to_string(),
expr_span: range_span,
});

let fresh_id = self.interner.next_type_variable_id();
let type_variable = Shared::new(TypeBinding::Unbound(fresh_id));
let expected_type =
Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField);

self.unify(&start_range_type, &expected_type, || {
TypeCheckError::TypeCannotBeUsed {
typ: start_range_type.clone(),
place: "for loop",
span: range_span,
}
.add_context("The range of a loop must be known at compile-time")
});

self.interner.push_definition_type(for_expr.identifier.id, start_range_type);

self.check_expression(&for_expr.block);
Type::Unit
}
HirExpression::Block(block_expr) => {
let mut block_type = Type::Unit;

Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,11 @@ mod test {
fn basic_for_expr() {
let src = r#"
fn main(_x : Field) {
let _j = for _i in 0..10 {
for _i in 0..10 {
for _k in 0..100 {
}
};
}
}
"#;
Expand Down
38 changes: 37 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ use noirc_errors::{Location, Span};

use crate::hir_def::expr::{HirExpression, HirIdent, HirLiteral};
use crate::hir_def::stmt::{
HirAssignStatement, HirConstrainStatement, HirLValue, HirLetStatement, HirPattern, HirStatement,
HirAssignStatement, HirConstrainStatement, HirForStatement, HirLValue, HirLetStatement,
HirPattern, HirStatement,
};
use crate::hir_def::types::Type;
use crate::node_interner::{DefinitionId, ExprId, StmtId};
use crate::{Shared, TypeBinding, TypeVariableKind};

use super::errors::{Source, TypeCheckError};
use super::TypeChecker;
Expand Down Expand Up @@ -48,11 +50,45 @@ impl<'interner> TypeChecker<'interner> {
HirStatement::Let(let_stmt) => self.check_let_stmt(let_stmt),
HirStatement::Constrain(constrain_stmt) => self.check_constrain_stmt(constrain_stmt),
HirStatement::Assign(assign_stmt) => self.check_assign_stmt(assign_stmt, stmt_id),
HirStatement::For(for_loop) => self.check_for_loop(for_loop),
HirStatement::Error => (),
}
Type::Unit
}

fn check_for_loop(&mut self, for_loop: HirForStatement) {
let start_range_type = self.check_expression(&for_loop.start_range);
let end_range_type = self.check_expression(&for_loop.end_range);

let start_span = self.interner.expr_span(&for_loop.start_range);
let end_span = self.interner.expr_span(&for_loop.end_range);

// Check that start range and end range have the same types
let range_span = start_span.merge(end_span);
self.unify(&start_range_type, &end_range_type, || TypeCheckError::TypeMismatch {
expected_typ: start_range_type.to_string(),
expr_typ: end_range_type.to_string(),
expr_span: range_span,
});

let fresh_id = self.interner.next_type_variable_id();
let type_variable = Shared::new(TypeBinding::Unbound(fresh_id));
let expected_type = Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField);

self.unify(&start_range_type, &expected_type, || {
TypeCheckError::TypeCannotBeUsed {
typ: start_range_type.clone(),
place: "for loop",
span: range_span,
}
.add_context("The range of a loop must be known at compile-time")
});

self.interner.push_definition_type(for_loop.identifier.id, start_range_type);

self.check_expression(&for_loop.block);
}

/// Associate a given HirPattern with the given Type, and remember
/// this association in the NodeInterner.
pub(crate) fn bind_pattern(&mut self, pattern: &HirPattern, typ: Type) {
Expand Down
9 changes: 0 additions & 9 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ pub enum HirExpression {
Call(HirCallExpression),
MethodCall(HirMethodCallExpression),
Cast(HirCastExpression),
For(HirForExpression),
If(HirIfExpression),
Tuple(Vec<ExprId>),
Lambda(HirLambda),
Expand All @@ -48,14 +47,6 @@ pub struct HirIdent {
pub id: DefinitionId,
}

#[derive(Debug, Clone)]
pub struct HirForExpression {
pub identifier: HirIdent,
pub start_range: ExprId,
pub end_range: ExprId,
pub block: ExprId,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct HirBinaryOp {
pub kind: BinaryOpKind,
Expand Down
Loading

0 comments on commit 0e266eb

Please sign in to comment.