From fe4e606e4bd8663b62d9934650457ab887eb223d Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 19 Apr 2024 11:20:00 +0200 Subject: [PATCH] fix: draft to simplify TDim a*(b+c) to a*b+b*c --- data/src/dim/tree.rs | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index 85a8801a53..0aa7734ef8 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -288,7 +288,20 @@ impl TDim { (_, 0) => Val(coef_prod), // Case #1: If 0 variables, return product (0, _) => Val(0), // Case #2: Result is 0 if coef is 0 (1, 1) => vars.remove(0), // Case #3: Product is 1, so return the only term - (1, _) => Mul(vars), // Case #4: Product is 1, so return the non-integer terms + (1, _) => { + if let Some(position) = vars.iter().position(|a| matches!(a, Add(_))) { + let add = vars.remove(position); + let Add(add) = add else { unreachable!() }; + vars.iter() + .cartesian_product(add) + .map(|(a, b)| a.clone() * b) + .sum::() + .simplify() + } else { + Mul(vars.into_iter().sorted_by(tdim_compare).collect()) + // Case #4: Product is 1, so return the non-integer terms + } + } (_, 1) => MulInt(coef_prod, Box::new(vars.remove(0))), // Case #5: Single variable, convert to 1 MulInt _ => MulInt(coef_prod, Box::new(Mul(vars))), // Case #6: Multiple variables, convert to MulInt } @@ -723,6 +736,8 @@ impl + PrimInt> ops::Rem for TDim { #[cfg(test)] mod tests { + use crate::prelude::ToDim; + use super::*; macro_rules! b( ($e:expr) => { Box::new($e) } ); @@ -927,6 +942,32 @@ mod tests { assert_eq!(e, TDim::from(0)); } + #[test] + fn reduce_distribute_simple() { + let a = S.0.sym("a").to_dim(); + let c = S.0.sym("c").to_dim(); + let d = S.0.sym("d").to_dim(); + + let e: TDim = a.clone() * (c.clone() + d.clone()); + let f: TDim = a.clone() * c.clone() + a.clone() * d.clone(); + assert_eq!(e, f); + } + + #[test] + fn reduce_distribute() { + let a = S.0.sym("a").to_dim(); + let b = S.0.sym("b").to_dim(); + let c = S.0.sym("c").to_dim(); + let d = S.0.sym("d").to_dim(); + + let e: TDim = (a.clone() + b.clone()) * (c.clone() + d.clone()); + let f: TDim = (a.clone() * c.clone()) + + a.clone() * d.clone() + + b.clone() * c.clone() + + b.clone() * d.clone(); + assert_eq!(e, f); + } + #[test] fn conv2d_ex_1() { let e = (TDim::from(1) - 1 + 1).div_ceil(1);