From 1cc9b8a59aa387ffb97e3a9bef809833d808f755 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 28 Jun 2024 16:15:02 +0200 Subject: [PATCH] TDim::mini, TDim::maxi --- core/src/ops/math/mod.rs | 2 + data/src/dim/mod.rs | 28 ++++++++ data/src/dim/tree.rs | 140 ++++++++++++++++++++++++++++++++++--- nnef/src/ops/nnef/deser.rs | 5 ++ nnef/src/ser.rs | 1 + 5 files changed, 167 insertions(+), 9 deletions(-) diff --git a/core/src/ops/math/mod.rs b/core/src/ops/math/mod.rs index c1f78211cf..bb4e2acaeb 100644 --- a/core/src/ops/math/mod.rs +++ b/core/src/ops/math/mod.rs @@ -225,6 +225,7 @@ bin_to_super_type!(min, Min, linalg:Min, q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *a } else { *b }; q_op_on_f32: |a: f32, b: f32| a.min(b), [f16, f32, f64] => |c,a,b| *c = a.min(*b), + [TDim] => |c,a,b| *c = a.clone().mini(b.clone()), [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.min(b)); bin_to_super_type!(max, Max, @@ -272,6 +273,7 @@ bin_to_super_type!(max, Max, q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *b } else { *a }; q_op_on_f32: |a: f32, b: f32| -> f32 {a.max(b)}, [f16, f32, f64] => |c,a,b| *c = a.max(*b), + [TDim] => |c,a,b| *c = a.clone().maxi(b.clone()), [i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.max(b)); bin_to_super_type!(pow, Pow, diff --git a/data/src/dim/mod.rs b/data/src/dim/mod.rs index c8255bafd5..8588cf30db 100644 --- a/data/src/dim/mod.rs +++ b/data/src/dim/mod.rs @@ -84,6 +84,10 @@ pub trait DimLike: fn substitute(&self, from: &Symbol, to: &Self) -> TractResult; fn broadcast(self, other: Self) -> TractResult; + /// Smallest of `self` and `other`. + fn mini(self, other: Self) -> Self; + /// Largest of `self` and `other`. + fn maxi(self, other: Self) -> Self; fn compatible_with(&self, other: &Self) -> bool; } @@ -167,6 +171,14 @@ impl DimLike for TDim { fn compatible_with(&self, other: &Self) -> bool { self.compatible_with(other) } + + fn mini(self, other: Self) -> Self { + TDim::Min(vec![self, other]).simplify() + } + + fn maxi(self, other: Self) -> Self { + TDim::Max(vec![self, 
other]).simplify() + } } impl<'a> std::convert::TryFrom<&'a TDim> for TDim { @@ -216,6 +228,22 @@ impl DimLike for usize { fn compatible_with(&self, other: &Self) -> bool { self == other } + + fn mini(self, other: Self) -> Self { + if self < other { + self + } else { + other + } + } + + fn maxi(self, other: Self) -> Self { + if self > other { + self + } else { + other + } + } } impl<'a> std::convert::TryFrom<&'a TDim> for usize { diff --git a/data/src/dim/tree.rs b/data/src/dim/tree.rs index 585e7df23f..6a103fb251 100644 --- a/data/src/dim/tree.rs +++ b/data/src/dim/tree.rs @@ -30,6 +30,8 @@ pub enum TDim { MulInt(i64, Box), Div(Box, u64), Broadcast(Vec), + Min(Vec), + Max(Vec), } use TDim::*; @@ -38,13 +40,15 @@ fn tdim_compare(a: &TDim, b: &TDim) -> Ordering { match (a, b) { (Sym(a), Sym(b)) => a.cmp(b), (Val(a), Val(b)) => a.cmp(b), - (Add(a), Add(b)) | (Mul(a), Mul(b)) | (Broadcast(a), Broadcast(b)) => { - a.len().cmp(&b.len()).then( - a.iter() - .zip(b.iter()) - .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_compare(a, b))), - ) - } + (Add(a), Add(b)) + | (Mul(a), Mul(b)) + | (Broadcast(a), Broadcast(b)) + | (Min(a), Min(b)) + | (Max(a), Max(b)) => a.len().cmp(&b.len()).then( + a.iter() + .zip(b.iter()) + .fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_compare(a, b))), + ), (MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_compare(d, e)), (Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_compare(d, e)), (Sym(_), _) => Ordering::Less, @@ -59,6 +63,10 @@ fn tdim_compare(a: &TDim, b: &TDim) -> Ordering { (_, MulInt(_, _)) => Ordering::Greater, (Broadcast(_), _) => Ordering::Less, (_, Broadcast(_)) => Ordering::Greater, + (Min(_), _) => Ordering::Less, + (_, Min(_)) => Ordering::Greater, + (Max(_), _) => Ordering::Less, + (_, Max(_)) => Ordering::Greater, } } @@ -70,6 +78,8 @@ impl fmt::Display for TDim { Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")), Mul(it) => write!(fmt, "{}", it.iter().map(|x| 
format!("({x})")).join("*")), Broadcast(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("#")), + Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")), + Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")), MulInt(a, b) => write!(fmt, "{a}*{b}"), Div(a, b) => write!(fmt, "({a})/{b}"), } @@ -113,6 +123,12 @@ impl TDim { Mul(terms) => { terms.iter().try_fold(1, |acc, it| it.eval_to_i64(values).map(|x| acc * x)) } + Min(terms) => terms + .iter() + .try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))), + Max(terms) => terms + .iter() + .try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))), Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| { it.eval_to_i64(values) .and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64)) @@ -128,6 +144,12 @@ impl TDim { Val(v) => Val(*v), Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }), Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }), + Min(terms) => { + terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) }) + } + Max(terms) => { + terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) }) + } Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone()) }), @@ -149,6 +171,12 @@ impl TDim { Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult { acc.broadcast(it.substitute(from, to)?) }), + Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult { + Ok(acc.mini(it.substitute(from, to)?)) + }), + Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult { + Ok(acc.maxi(it.substitute(from, to)?)) + }), Div(a, q) => Ok(a.substitute(from, to)? / *q as i64), MulInt(p, a) => Ok(a.substitute(from, to)? 
* *p), } @@ -172,6 +200,7 @@ impl TDim { Add(terms) => 2 * terms.iter().map(TDim::cost).sum::(), Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::(), Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::(), + Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::(), Div(a, _) => 3 * a.cost(), MulInt(_, a) => 2 * a.cost(), } @@ -180,7 +209,7 @@ impl TDim { fn wiggle(&self) -> Vec { use self::TDim::*; match self { - Sym(_) | Val(_) | Mul(_) | Broadcast(_) => vec![self.clone()], + Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) => vec![self.clone()], Add(terms) => { let mut forms = vec![]; let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product(); @@ -419,6 +448,58 @@ impl TDim { Broadcast(terms) } } + Min(terms) => { + let flatten: Vec = terms + .into_iter() + .map(TDim::simplify) + .flat_map(|t| if let Min(t) = t { t } else { vec![t] }) + .sorted_by(tdim_compare) + .dedup() + .collect(); + let mut new_terms: Vec = flatten + .iter() + .filter(|&t| { + t != &i64::MAX.to_dim() + && !flatten + .iter() + .any(|other| (t.clone() - other).to_i64().is_ok_and(|i| i > 0)) + }) + .cloned() + .collect(); + if new_terms.is_empty() { + i64::MAX.to_dim() + } else if new_terms.len() == 1 { + new_terms.remove(0) + } else { + Min(new_terms) + } + } + Max(terms) => { + let flatten: Vec = terms + .into_iter() + .map(TDim::simplify) + .flat_map(|t| if let Max(t) = t { t } else { vec![t] }) + .sorted_by(tdim_compare) + .dedup() + .collect(); + let mut new_terms: Vec = flatten + .iter() + .filter(|&t| { + t != &i64::MIN.to_dim() + && !flatten + .iter() + .any(|other| (t.clone() - other).to_i64().is_ok_and(|i| i < 0)) + }) + .cloned() + .collect(); + if new_terms.is_empty() { + i64::MIN.to_dim() + } else if new_terms.len() == 1 { + new_terms.remove(0) + } else { + Max(new_terms) + } + } Val(_) | Sym(_) => self, } } @@ -435,6 +516,8 @@ impl TDim { } MulInt(p, a) => a.gcd() * p.unsigned_abs(), Mul(terms) => terms.iter().map(|t| 
t.gcd()).product(), + Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(), + Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(), Div(a, q) => { if a.gcd() % *q == 0 { a.gcd() / *q @@ -456,6 +539,8 @@ impl TDim { Val(v) => Val(v / d as i64), Sym(_) => panic!(), Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()), + Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()), + Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()), Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()), Mul(_) => Div(Box::new(self.clone()), d), MulInt(p, a) => { @@ -496,6 +581,8 @@ impl TDim { (n, d * *q as i64) } Broadcast(terms) => slope_rec(&terms[0], sym), + Min(terms) => slope_rec(&terms[0], sym), + Max(terms) => slope_rec(&terms[0], sym), } } let (p, q) = slope_rec(self, sym); @@ -507,7 +594,7 @@ impl TDim { match self { Val(_) => maplit::hashset!(), Sym(s) => maplit::hashset!(s.clone()), - Add(terms) | Mul(terms) | Broadcast(terms) => { + Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => { terms.iter().fold(maplit::hashset!(), |mut set, v| { set.extend(v.symbols()); set @@ -995,4 +1082,39 @@ mod tests { let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3); assert_eq!(mul1, mul2); } + + #[test] + fn min_ints_1() { + assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim()); + } + + #[test] + fn min_ints_2() { + assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim()); + } + + #[test] + fn min_same() { + assert_eq!(s().mini(s()), s()); + } + + #[test] + fn min_noop() { + assert_eq!(s().mini(1.to_dim()), s().mini(1.to_dim())); + } + + #[test] + fn min_diff_1() { + assert_eq!((s() + 1).mini(s() + 2), s() + 1); + } + + #[test] + fn max_ints_1() { + assert_eq!(2.to_dim().maxi(1.to_dim()), 2.to_dim()); + } + + #[test] + fn max_diff_1() { + assert_eq!((s() + 1).maxi(s() + 2), s() + 2); + } } diff --git a/nnef/src/ops/nnef/deser.rs b/nnef/src/ops/nnef/deser.rs index 5ab425584f..a0731bde2b 
100644 --- a/nnef/src/ops/nnef/deser.rs +++ b/nnef/src/ops/nnef/deser.rs @@ -9,6 +9,7 @@ use tract_core::ops::cnn::deconv::adjustments; use tract_core::ops::cnn::PaddingSpec; 
use tract_core::ops::cnn::PoolSpec; use tract_core::ops::konst::Const; +use tract_core::ops::math::min; use tract_core::ops::nn::{DataFormat, Softmax, SoftmaxExp}; use tract_itertools::Itertools; @@ -159,11 +160,14 @@ pub fn slice(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tra let strides: TVec = invocation.named_arg_as(builder, "stride").unwrap_or_else(|_| tvec!(1; axes.len())); for (ix, axis) in axes.into_iter().enumerate() { + let axis_len = + builder.wire_as_outlets(Const(rctensor0(input_fact.shape[axis].clone())), &[])?[0]; let b = builder.wire_as_outlets( tract_core::ops::array::Slice { axis: 0, start: ix.into(), end: ix.to_dim() + 1 }, &[begins], )?; let mut b = builder.wire_as_outlets(tract_core::ops::change_axes::AxisOp::Rm(0), &b)?; + b = builder.wire_as_outlets(min(), &[b[0], axis_len])?; if let Some(k) = &builder.model.outlet_fact(b[0])?.konst { if let Ok(i) = k.cast_to_scalar::() { if i < 0 { @@ -179,6 +183,7 @@ pub fn slice(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tra &[ends], )?; let mut e = builder.wire_as_outlets(tract_core::ops::change_axes::AxisOp::Rm(0), &e)?; + e = builder.wire_as_outlets(min(), &[e[0], axis_len])?; // use "<=", no "<" end[axis] = 0 means "up to the end" // CAUTION: this notation is 1/ deprecated 2/ invalid with non trivial slicing if let Some(k) = &builder.model.outlet_fact(e[0])?.konst { diff --git a/nnef/src/ser.rs b/nnef/src/ser.rs index 09128e059f..2bbd845462 100644 --- a/nnef/src/ser.rs +++ b/nnef/src/ser.rs @@ -448,6 +448,7 @@ pub fn tdim(dim: &TDim) -> RValue { TDim::MulInt(x, y) => RValue::Binary(numeric(x).boxed(), "*".to_string(), tdim(y).boxed()), TDim::Div(x, y) => RValue::Binary(tdim(x).boxed(), "/".to_string(), numeric(y).boxed()), TDim::Broadcast(_) => todo!(), + TDim::Min(_) | TDim::Max(_) => todo!(), } }