Skip to content

Commit

Permalink
fix: perform arithmetic simplification through CheckedCast (#6502)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored Nov 12, 2024
1 parent 21c9db5 commit 72e8de0
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 26 deletions.
108 changes: 82 additions & 26 deletions compiler/noirc_frontend/src/hir_def/types/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,20 +189,27 @@ impl Type {
op: BinaryTypeOperator,
rhs: &Type,
) -> Option<Type> {
let Type::InfixExpr(l_lhs, l_op, l_rhs) = lhs.follow_bindings() else {
return None;
};
match lhs.follow_bindings() {
Type::CheckedCast { from, to } => {
// Apply operation directly to `from` while attempting simplification to `to`.
let from = Type::InfixExpr(from, op, Box::new(rhs.clone()));
let to = Self::try_simplify_non_constants_in_lhs(&to, op, rhs)?;
Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) })
}
Type::InfixExpr(l_lhs, l_op, l_rhs) => {
// Note that this is exact, syntactic equality, not unification.
// `rhs` is expected to already be in canonical form.
if l_op.approx_inverse() != Some(op)
|| l_op == BinaryTypeOperator::Division
|| l_rhs.canonicalize_unchecked() != *rhs
{
return None;
}

// Note that this is exact, syntactic equality, not unification.
// `rhs` is expected to already be in canonical form.
if l_op.approx_inverse() != Some(op)
|| l_op == BinaryTypeOperator::Division
|| l_rhs.canonicalize_unchecked() != *rhs
{
return None;
Some(*l_lhs)
}
_ => None,
}

Some(*l_lhs)
}

/// Try to simplify non-constant expressions in the form `N op1 (M op1 N)`
Expand All @@ -219,23 +226,31 @@ impl Type {
op: BinaryTypeOperator,
rhs: &Type,
) -> Option<Type> {
let Type::InfixExpr(r_lhs, r_op, r_rhs) = rhs.follow_bindings() else {
return None;
};
match rhs.follow_bindings() {
Type::CheckedCast { from, to } => {
// Apply operation directly to `from` while attempting simplification to `to`.
let from = Type::InfixExpr(Box::new(lhs.clone()), op, from);
let to = Self::try_simplify_non_constants_in_rhs(lhs, op, &to)?;
Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) })
}
Type::InfixExpr(r_lhs, r_op, r_rhs) => {
// `N / (M * N)` should be simplified to `1 / M`, but we only handle
// simplifying to `M` in this function.
if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication
{
return None;
}

// `N / (M * N)` should be simplified to `1 / M`, but we only handle
// simplifying to `M` in this function.
if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication {
return None;
}
// Note that this is exact, syntactic equality, not unification.
// `lhs` is expected to already be in canonical form.
if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize_unchecked() {
return None;
}

// Note that this is exact, syntactic equality, not unification.
// `lhs` is expected to already be in canonical form.
if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize_unchecked() {
return None;
Some(*r_lhs)
}
_ => None,
}

Some(*r_lhs)
}

/// Given:
Expand Down Expand Up @@ -360,6 +375,47 @@ mod tests {

use crate::hir_def::types::{BinaryTypeOperator, Kind, Type, TypeVariable, TypeVariableId};

#[test]
fn solves_n_minus_one_plus_one_through_checked_casts() {
// We want to test that the inclusion of a `CheckedCast` won't prevent us from canonicalizing
// the expression `(N - 1) + 1` to `N` if there exists a `CheckedCast` on the `N - 1` term.

let n = Type::NamedGeneric(
TypeVariable::unbound(TypeVariableId(0), Kind::u32()),
std::rc::Rc::new("N".to_owned()),
);
let n_minus_one = Type::InfixExpr(
Box::new(n.clone()),
BinaryTypeOperator::Subtraction,
Box::new(Type::Constant(FieldElement::one(), Kind::u32())),
);
let checked_cast_n_minus_one =
Type::CheckedCast { from: Box::new(n_minus_one.clone()), to: Box::new(n_minus_one) };

let n_minus_one_plus_one = Type::InfixExpr(
Box::new(checked_cast_n_minus_one.clone()),
BinaryTypeOperator::Addition,
Box::new(Type::Constant(FieldElement::one(), Kind::u32())),
);

let canonicalized_typ = n_minus_one_plus_one.canonicalize();

assert_eq!(n, canonicalized_typ);

// We also want to check that if the `CheckedCast` is on the RHS then we'll still be able to canonicalize
// the expression `1 + (N - 1)` to `N`.

let one_plus_n_minus_one = Type::InfixExpr(
Box::new(Type::Constant(FieldElement::one(), Kind::u32())),
BinaryTypeOperator::Addition,
Box::new(checked_cast_n_minus_one),
);

let canonicalized_typ = one_plus_n_minus_one.canonicalize();

assert_eq!(n, canonicalized_typ);
}

#[test]
fn instantiate_after_canonicalize_smoke_test() {
let field_element_kind = Kind::numeric(Type::FieldElement);
Expand Down
30 changes: 30 additions & 0 deletions compiler/noirc_frontend/src/tests/arithmetic_generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,36 @@ fn arithmetic_generics_canonicalization_deduplication_regression() {
assert_eq!(errors.len(), 0);
}

#[test]
fn checked_casts_do_not_prevent_canonicalization() {
// Regression test for https://github.com/noir-lang/noir/issues/6495
let source = r#"
pub trait Serialize<let N: u32> {
fn serialize(self) -> [Field; N];
}
pub struct Counted<T> {
pub inner: T,
}
pub fn append<T, let N: u32>(array1: [T; N]) -> [T; N + 1] {
[array1[0]; N + 1]
}
impl<T, let N: u32> Serialize<N> for Counted<T>
where
T: Serialize<N - 1>,
{
fn serialize(self) -> [Field; N] {
append(self.inner.serialize())
}
}
"#;
let errors = get_program_errors(source);
println!("{:?}", errors);
assert_eq!(errors.len(), 0);
}

#[test]
fn arithmetic_generics_checked_cast_zeros() {
let source = r#"
Expand Down
2 changes: 2 additions & 0 deletions cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
"callsites",
"callstack",
"callstacks",
"canonicalization",
"canonicalize",
"canonicalized",
"canonicalizing",
"castable",
"catmcgee",
"Celo",
Expand Down

0 comments on commit 72e8de0

Please sign in to comment.