diff --git a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs index f4646c180a3..8cdf6f5502c 100644 --- a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs +++ b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -189,20 +189,27 @@ impl Type { op: BinaryTypeOperator, rhs: &Type, ) -> Option { - let Type::InfixExpr(l_lhs, l_op, l_rhs) = lhs.follow_bindings() else { - return None; - }; + match lhs.follow_bindings() { + Type::CheckedCast { from, to } => { + // Apply operation directly to `from` while attempting simplification to `to`. + let from = Type::InfixExpr(from, op, Box::new(rhs.clone())); + let to = Self::try_simplify_non_constants_in_lhs(&to, op, rhs)?; + Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) }) + } + Type::InfixExpr(l_lhs, l_op, l_rhs) => { + // Note that this is exact, syntactic equality, not unification. + // `rhs` is expected to already be in canonical form. + if l_op.approx_inverse() != Some(op) + || l_op == BinaryTypeOperator::Division + || l_rhs.canonicalize_unchecked() != *rhs + { + return None; + } - // Note that this is exact, syntactic equality, not unification. - // `rhs` is expected to already be in canonical form. - if l_op.approx_inverse() != Some(op) - || l_op == BinaryTypeOperator::Division - || l_rhs.canonicalize_unchecked() != *rhs - { - return None; + Some(*l_lhs) + } + _ => None, } - - Some(*l_lhs) } /// Try to simplify non-constant expressions in the form `N op1 (M op1 N)` @@ -219,23 +226,31 @@ impl Type { op: BinaryTypeOperator, rhs: &Type, ) -> Option { - let Type::InfixExpr(r_lhs, r_op, r_rhs) = rhs.follow_bindings() else { - return None; - }; + match rhs.follow_bindings() { + Type::CheckedCast { from, to } => { + // Apply operation directly to `from` while attempting simplification to `to`. + let from = Type::InfixExpr(Box::new(lhs.clone()), op, from); + let to = Self::try_simplify_non_constants_in_rhs(lhs, op, &to)?; + Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) }) + } + Type::InfixExpr(r_lhs, r_op, r_rhs) => { + // `N / (M * N)` should be simplified to `1 / M`, but we only handle + // simplifying to `M` in this function. + if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication + { + return None; + } - // `N / (M * N)` should be simplified to `1 / M`, but we only handle - // simplifying to `M` in this function. - if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication { - return None; - } + // Note that this is exact, syntactic equality, not unification. + // `lhs` is expected to already be in canonical form. + if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize_unchecked() { + return None; + } - // Note that this is exact, syntactic equality, not unification. - // `lhs` is expected to already be in canonical form. - if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize_unchecked() { - return None; + Some(*r_lhs) + } + _ => None, } - - Some(*r_lhs) } /// Given: @@ -360,6 +375,47 @@ mod tests { use crate::hir_def::types::{BinaryTypeOperator, Kind, Type, TypeVariable, TypeVariableId}; + #[test] + fn solves_n_minus_one_plus_one_through_checked_casts() { + // We want to test that the inclusion of a `CheckedCast` won't prevent us from canonicalizing + // the expression `(N - 1) + 1` to `N` if there exists a `CheckedCast` on the `N - 1` term. + + let n = Type::NamedGeneric( + TypeVariable::unbound(TypeVariableId(0), Kind::u32()), + std::rc::Rc::new("N".to_owned()), + ); + let n_minus_one = Type::InfixExpr( + Box::new(n.clone()), + BinaryTypeOperator::Subtraction, + Box::new(Type::Constant(FieldElement::one(), Kind::u32())), + ); + let checked_cast_n_minus_one = + Type::CheckedCast { from: Box::new(n_minus_one.clone()), to: Box::new(n_minus_one) }; + + let n_minus_one_plus_one = Type::InfixExpr( + Box::new(checked_cast_n_minus_one.clone()), + BinaryTypeOperator::Addition, + Box::new(Type::Constant(FieldElement::one(), Kind::u32())), + ); + + let canonicalized_typ = n_minus_one_plus_one.canonicalize(); + + assert_eq!(n, canonicalized_typ); + + // We also want to check that if the `CheckedCast` is on the RHS then we'll still be able to canonicalize + // the expression `1 + (N - 1)` to `N`. + + let one_plus_n_minus_one = Type::InfixExpr( + Box::new(Type::Constant(FieldElement::one(), Kind::u32())), + BinaryTypeOperator::Addition, + Box::new(checked_cast_n_minus_one), + ); + + let canonicalized_typ = one_plus_n_minus_one.canonicalize(); + + assert_eq!(n, canonicalized_typ); + } + #[test] fn instantiate_after_canonicalize_smoke_test() { let field_element_kind = Kind::numeric(Type::FieldElement); diff --git a/compiler/noirc_frontend/src/tests/arithmetic_generics.rs b/compiler/noirc_frontend/src/tests/arithmetic_generics.rs index 3fa41b80279..3328fe439ae 100644 --- a/compiler/noirc_frontend/src/tests/arithmetic_generics.rs +++ b/compiler/noirc_frontend/src/tests/arithmetic_generics.rs @@ -27,6 +27,36 @@ fn arithmetic_generics_canonicalization_deduplication_regression() { assert_eq!(errors.len(), 0); } +#[test] +fn checked_casts_do_not_prevent_canonicalization() { + // Regression test for https://github.com/noir-lang/noir/issues/6495 + let source = r#" + pub trait Serialize { + fn serialize(self) -> [Field; N]; + } + + pub struct Counted { + pub inner: T, + } + + pub fn append(array1: [T; N]) -> [T; N + 1] { + [array1[0]; N + 1] + } + + impl Serialize for Counted + where + T: Serialize, + { + fn serialize(self) -> [Field; N] { + append(self.inner.serialize()) + } + } + "#; + let errors = get_program_errors(source); + println!("{:?}", errors); + assert_eq!(errors.len(), 0); +} + #[test] fn arithmetic_generics_checked_cast_zeros() { let source = r#" diff --git a/cspell.json b/cspell.json index aa0f749bd9c..d8315c35910 100644 --- a/cspell.json +++ b/cspell.json @@ -38,8 +38,10 @@ "callsites", "callstack", "callstacks", + "canonicalization", "canonicalize", "canonicalized", + "canonicalizing", "castable", "catmcgee", "Celo",