Skip to content

Commit

Permalink
Merge 909873a into 57afc7d
Browse files Browse the repository at this point in the history
  • Loading branch information
jfecher authored Sep 13, 2024
2 parents 57afc7d + 909873a commit 2d4d971
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
22 changes: 15 additions & 7 deletions compiler/noirc_frontend/src/hir_def/types/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::BTreeSet;
use std::collections::BTreeMap;

use crate::{BinaryTypeOperator, Type, TypeBindings, UnificationError};

Expand Down Expand Up @@ -52,7 +52,8 @@ impl Type {
fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type {
let mut queue = vec![lhs.clone(), rhs.clone()];

let mut sorted = BTreeSet::new();
// Maps each term to the number of times that term was used.
let mut sorted = BTreeMap::new();

let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 };
let mut constant = zero_value;
Expand All @@ -68,20 +69,27 @@ impl Type {
if let Some(result) = op.function(constant, new_constant) {
constant = result;
} else {
sorted.insert(Type::Constant(new_constant));
*sorted.entry(Type::Constant(new_constant)).or_default() += 1;
}
}
other => {
sorted.insert(other);
*sorted.entry(other).or_default() += 1;
}
}
}

if let Some(first) = sorted.pop_first() {
let mut typ = first.clone();
let (mut typ, first_type_count) = first.clone();

for rhs in sorted {
typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone()));
// - 1 since `typ` already is set to the first instance
for _ in 0..first_type_count - 1 {
typ = Type::InfixExpr(Box::new(typ), op, Box::new(first.0.clone()));
}

for (rhs, rhs_count) in sorted {
for _ in 0..rhs_count {
typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone()));
}
}

if constant != zero_value {
Expand Down
19 changes: 19 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3463,6 +3463,25 @@ fn comptime_type_in_runtime_code() {
));
}

#[test]
fn arithmetic_generics_canonicalization_deduplication_regression() {
let source = r#"
struct ArrData<let N: u32> {
a: [Field; N],
b: [Field; N + N - 1],
}
fn main() {
let _f: ArrData<5> = ArrData {
a: [0; 5],
b: [0; 9],
};
}
"#;
let errors = get_program_errors(source);
assert_eq!(errors.len(), 0);
}

#[test]
fn cannot_mutate_immutable_variable() {
let src = r#"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,11 @@ fn demo_proof<let N: u32>() -> Equiv<W<(N * (N + 1))>, (Equiv<W<N>, (), W<N>, ()
let p1: Equiv<W<(N + 1) * N>, (), W<N * (N + 1)>, ()> = mul_comm();
let p2: Equiv<W<N * (N + 1)>, (), W<N * N + N>, ()> = mul_add::<N, N, 1>();
let p3_sub: Equiv<W<N>, (), W<N>, ()> = mul_one_r();
let p3: Equiv<W<N * N + N>, (), W<N * N + N>, ()> = add_equiv_r::<N * N, N, N, _, _>(p3_sub);
equiv_trans(equiv_trans(p1, p2), p3)
let _p3: Equiv<W<N * N + N>, (), W<N * N + N>, ()> = add_equiv_r::<N * N, N, N, _, _>(p3_sub);
let _p1_to_2 = equiv_trans(p1, p2);

// equiv_trans(p1_to_2, p3)
std::mem::zeroed()
}

fn test_constant_folding<let N: u32>() {
Expand Down

0 comments on commit 2d4d971

Please sign in to comment.