Skip to content

Commit

Permalink
TDim::mini, TDim::maxi
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jun 28, 2024
1 parent 664cfc1 commit 1cc9b8a
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 9 deletions.
2 changes: 2 additions & 0 deletions core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ bin_to_super_type!(min, Min, linalg:Min,
q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *a } else { *b };
q_op_on_f32: |a: f32, b: f32| a.min(b),
[f16, f32, f64] => |c,a,b| *c = a.min(*b),
[TDim] => |c,a,b| *c = a.clone().mini(b.clone()),
[i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.min(b));

bin_to_super_type!(max, Max,
Expand Down Expand Up @@ -272,6 +273,7 @@ bin_to_super_type!(max, Max,
q: [i8, u8, i32] => |c, a, b, _, _| *c = if a < b { *b } else { *a };
q_op_on_f32: |a: f32, b: f32| -> f32 {a.max(b)},
[f16, f32, f64] => |c,a,b| *c = a.max(*b),
[TDim] => |c,a,b| *c = a.clone().maxi(b.clone()),
[i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a.max(b));

bin_to_super_type!(pow, Pow,
Expand Down
26 changes: 26 additions & 0 deletions data/src/dim/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ pub trait DimLike:
fn substitute(&self, from: &Symbol, to: &Self) -> TractResult<Self>;

fn broadcast(self, other: Self) -> TractResult<Self>;
fn mini(self, other: Self) -> Self;
fn maxi(self, other: Self) -> Self;

fn compatible_with(&self, other: &Self) -> bool;
}
Expand Down Expand Up @@ -167,6 +169,14 @@ impl DimLike for TDim {
fn compatible_with(&self, other: &Self) -> bool {
self.compatible_with(other)
}

fn mini(self, other: Self) -> Self {
TDim::Min(vec![self, other]).simplify()
}

fn maxi(self, other: Self) -> Self {
TDim::Min(vec![self, other]).simplify()
}
}

impl<'a> std::convert::TryFrom<&'a TDim> for TDim {
Expand Down Expand Up @@ -216,6 +226,22 @@ impl DimLike for usize {
fn compatible_with(&self, other: &Self) -> bool {
self == other
}

fn mini(self, other: Self) -> Self {
if self < other {
self
} else {
other
}
}

fn maxi(self, other: Self) -> Self {
if self > other {
self
} else {
other
}
}
}

impl<'a> std::convert::TryFrom<&'a TDim> for usize {
Expand Down
130 changes: 121 additions & 9 deletions data/src/dim/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub enum TDim {
MulInt(i64, Box<TDim>),
Div(Box<TDim>, u64),
Broadcast(Vec<TDim>),
Min(Vec<TDim>),
Max(Vec<TDim>),
}

use TDim::*;
Expand All @@ -38,13 +40,15 @@ fn tdim_compare(a: &TDim, b: &TDim) -> Ordering {
match (a, b) {
(Sym(a), Sym(b)) => a.cmp(b),
(Val(a), Val(b)) => a.cmp(b),
(Add(a), Add(b)) | (Mul(a), Mul(b)) | (Broadcast(a), Broadcast(b)) => {
a.len().cmp(&b.len()).then(
a.iter()
.zip(b.iter())
.fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_compare(a, b))),
)
}
(Add(a), Add(b))
| (Mul(a), Mul(b))
| (Broadcast(a), Broadcast(b))
| (Min(a), Min(b))
| (Max(a), Max(b)) => a.len().cmp(&b.len()).then(
a.iter()
.zip(b.iter())
.fold(Ordering::Equal, |acc, (a, b)| acc.then_with(|| tdim_compare(a, b))),
),
(MulInt(p, d), MulInt(q, e)) => p.cmp(q).then_with(|| tdim_compare(d, e)),
(Div(d, p), Div(e, q)) => p.cmp(q).then_with(|| tdim_compare(d, e)),
(Sym(_), _) => Ordering::Less,
Expand All @@ -59,6 +63,10 @@ fn tdim_compare(a: &TDim, b: &TDim) -> Ordering {
(_, MulInt(_, _)) => Ordering::Greater,
(Broadcast(_), _) => Ordering::Less,
(_, Broadcast(_)) => Ordering::Greater,
(Min(_), _) => Ordering::Less,
(_, Min(_)) => Ordering::Greater,
(Max(_), _) => Ordering::Less,
(_, Max(_)) => Ordering::Greater,
}
}

Expand All @@ -70,6 +78,8 @@ impl fmt::Display for TDim {
Add(it) => write!(fmt, "{}", it.iter().map(|x| format!("{x}")).join("+")),
Mul(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("*")),
Broadcast(it) => write!(fmt, "{}", it.iter().map(|x| format!("({x})")).join("#")),
Min(it) => write!(fmt, "min({})", it.iter().map(|x| format!("{x}")).join(",")),
Max(it) => write!(fmt, "max({})", it.iter().map(|x| format!("{x}")).join(",")),
MulInt(a, b) => write!(fmt, "{a}*{b}"),
Div(a, b) => write!(fmt, "({a})/{b}"),
}
Expand Down Expand Up @@ -113,6 +123,12 @@ impl TDim {
Mul(terms) => {
terms.iter().try_fold(1, |acc, it| it.eval_to_i64(values).map(|x| acc * x))
}
Min(terms) => terms
.iter()
.try_fold(i64::MAX, |acc, it| it.eval_to_i64(values).map(|x| acc.min(x))),
Max(terms) => terms
.iter()
.try_fold(i64::MIN, |acc, it| it.eval_to_i64(values).map(|x| acc.max(x))),
Broadcast(terms) => terms.iter().try_fold(1i64, |acc, it| {
it.eval_to_i64(values)
.and_then(|x| ((acc as usize).broadcast(x as usize)).map(|x| x as i64))
Expand All @@ -128,6 +144,12 @@ impl TDim {
Val(v) => Val(*v),
Add(terms) => terms.iter().fold(Val(0), |acc, it| -> TDim { acc + it.eval(values) }),
Mul(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim { acc * it.eval(values) }),
Min(terms) => {
terms.iter().fold(Val(i64::MAX), |acc, it| -> TDim { acc.mini(it.eval(values)) })
}
Max(terms) => {
terms.iter().fold(Val(i64::MIN), |acc, it| -> TDim { acc.maxi(it.eval(values)) })
}
Broadcast(terms) => terms.iter().fold(Val(1), |acc, it| -> TDim {
acc.broadcast(it.eval(values)).unwrap_or_else(|_| self.clone())
}),
Expand All @@ -149,6 +171,12 @@ impl TDim {
Broadcast(terms) => terms.iter().try_fold(Val(1), |acc, it| -> TractResult<TDim> {
acc.broadcast(it.substitute(from, to)?)
}),
Min(terms) => terms.iter().try_fold(Val(i64::MAX), |acc, it| -> TractResult<TDim> {
Ok(acc.mini(it.substitute(from, to)?))
}),
Max(terms) => terms.iter().try_fold(Val(i64::MIN), |acc, it| -> TractResult<TDim> {
Ok(acc.maxi(it.substitute(from, to)?))
}),
Div(a, q) => Ok(a.substitute(from, to)? / *q as i64),
MulInt(p, a) => Ok(a.substitute(from, to)? * *p),
}
Expand All @@ -172,6 +200,7 @@ impl TDim {
Add(terms) => 2 * terms.iter().map(TDim::cost).sum::<usize>(),
Mul(terms) => 3 * terms.iter().map(TDim::cost).sum::<usize>(),
Broadcast(terms) => 4 * terms.iter().map(TDim::cost).sum::<usize>(),
Min(terms) | Max(terms) => 5 * terms.iter().map(TDim::cost).sum::<usize>(),
Div(a, _) => 3 * a.cost(),
MulInt(_, a) => 2 * a.cost(),
}
Expand All @@ -180,7 +209,7 @@ impl TDim {
fn wiggle(&self) -> Vec<TDim> {
use self::TDim::*;
match self {
Sym(_) | Val(_) | Mul(_) | Broadcast(_) => vec![self.clone()],
Sym(_) | Val(_) | Mul(_) | Broadcast(_) | Min(_) | Max(_) => vec![self.clone()],
Add(terms) => {
let mut forms = vec![];
let sub_exprs = terms.iter().map(|e| e.wiggle()).multi_cartesian_product();
Expand Down Expand Up @@ -419,6 +448,58 @@ impl TDim {
Broadcast(terms)
}
}
Min(terms) => {
let flatten: Vec<TDim> = terms
.into_iter()
.map(TDim::simplify)
.flat_map(|t| if let Min(t) = t { t } else { vec![t] })
.sorted_by(tdim_compare)
.dedup()
.collect();
let mut new_terms: Vec<TDim> = flatten
.iter()
.filter(|&t| {
t != &i64::MAX.to_dim()
&& !flatten
.iter()
.any(|other| (t.clone() - other).to_i64().is_ok_and(|i| i > 0))
})
.cloned()
.collect();
if new_terms.len() == 0 {
i64::MAX.to_dim()
} else if new_terms.len() == 1 {
new_terms.remove(0)
} else {
Min(new_terms)
}
}
Max(terms) => {
let flatten: Vec<TDim> = terms
.into_iter()
.map(TDim::simplify)
.flat_map(|t| if let Max(t) = t { t } else { vec![t] })
.sorted_by(tdim_compare)
.dedup()
.collect();
let mut new_terms: Vec<TDim> = flatten
.iter()
.filter(|&t| {
t != &i64::MIN.to_dim()
&& !flatten
.iter()
.any(|other| (t.clone() - other).to_i64().is_ok_and(|i| i < 0))
})
.cloned()
.collect();
if new_terms.len() == 0 {
i64::MIN.to_dim()
} else if new_terms.len() == 1 {
new_terms.remove(0)
} else {
Max(new_terms)
}
}
Val(_) | Sym(_) => self,
}
}
Expand All @@ -435,6 +516,8 @@ impl TDim {
}
MulInt(p, a) => a.gcd() * p.unsigned_abs(),
Mul(terms) => terms.iter().map(|t| t.gcd()).product(),
Min(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
Max(terms) => terms.iter().map(|t| t.gcd()).reduce(|a, b| a.gcd(&b)).unwrap(),
Div(a, q) => {
if a.gcd() % *q == 0 {
a.gcd() / *q
Expand All @@ -456,6 +539,8 @@ impl TDim {
Val(v) => Val(v / d as i64),
Sym(_) => panic!(),
Add(terms) => Add(terms.iter().map(|t| t.div(d)).collect()),
Min(terms) => Min(terms.iter().map(|t| t.div(d)).collect()),
Max(terms) => Max(terms.iter().map(|t| t.div(d)).collect()),
Broadcast(terms) => Broadcast(terms.iter().map(|t| t.div(d)).collect()),
Mul(_) => Div(Box::new(self.clone()), d),
MulInt(p, a) => {
Expand Down Expand Up @@ -496,6 +581,8 @@ impl TDim {
(n, d * *q as i64)
}
Broadcast(terms) => slope_rec(&terms[0], sym),
Min(terms) => slope_rec(&terms[0], sym),
Max(terms) => slope_rec(&terms[0], sym),
}
}
let (p, q) = slope_rec(self, sym);
Expand All @@ -507,7 +594,7 @@ impl TDim {
match self {
Val(_) => maplit::hashset!(),
Sym(s) => maplit::hashset!(s.clone()),
Add(terms) | Mul(terms) | Broadcast(terms) => {
Add(terms) | Mul(terms) | Broadcast(terms) | Min(terms) | Max(terms) => {
terms.iter().fold(maplit::hashset!(), |mut set, v| {
set.extend(v.symbols());
set
Expand Down Expand Up @@ -995,4 +1082,29 @@ mod tests {
let mul2 = (term.clone() - 1) * (term.clone() * 2 - 3);
assert_eq!(mul1, mul2);
}

#[test]
fn min_ints_1() {
assert_eq!(2.to_dim().mini(1.to_dim()), 1.to_dim());
}

#[test]
fn min_ints_2() {
assert_eq!(1.to_dim().mini(2.to_dim()), 1.to_dim());
}

#[test]
fn min_same() {
assert_eq!(s().mini(s()), s());
}

#[test]
fn min_noop() {
assert_eq!(s().mini(1.to_dim()), s().mini(1.to_dim()));
}

#[test]
fn min_diff_1() {
assert_eq!((s() + 1).mini(s() + 2), s() + 1);
}
}
5 changes: 5 additions & 0 deletions nnef/src/ops/nnef/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use tract_core::ops::cnn::deconv::adjustments;
use tract_core::ops::cnn::PaddingSpec;
use tract_core::ops::cnn::PoolSpec;
use tract_core::ops::konst::Const;
use tract_core::ops::math::min;
use tract_core::ops::nn::{DataFormat, Softmax, SoftmaxExp};
use tract_itertools::Itertools;

Expand Down Expand Up @@ -159,11 +160,14 @@ pub fn slice(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tra
let strides: TVec<isize> =
invocation.named_arg_as(builder, "stride").unwrap_or_else(|_| tvec!(1; axes.len()));
for (ix, axis) in axes.into_iter().enumerate() {
let axis_len =
builder.wire_as_outlets(Const(rctensor0(input_fact.shape[axis].clone())), &[])?[0];
let b = builder.wire_as_outlets(
tract_core::ops::array::Slice { axis: 0, start: ix.into(), end: ix.to_dim() + 1 },
&[begins],
)?;
let mut b = builder.wire_as_outlets(tract_core::ops::change_axes::AxisOp::Rm(0), &b)?;
b = builder.wire_as_outlets(min(), &[b[0], axis_len])?;
if let Some(k) = &builder.model.outlet_fact(b[0])?.konst {
if let Ok(i) = k.cast_to_scalar::<i64>() {
if i < 0 {
Expand All @@ -179,6 +183,7 @@ pub fn slice(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tra
&[ends],
)?;
let mut e = builder.wire_as_outlets(tract_core::ops::change_axes::AxisOp::Rm(0), &e)?;
e = builder.wire_as_outlets(min(), &[e[0], axis_len])?;
// use "<=", no "<" end[axis] = 0 means "up to the end"
// CAUTION: this notation is 1/ deprecated 2/ invalid with non trivial slicing
if let Some(k) = &builder.model.outlet_fact(e[0])?.konst {
Expand Down
1 change: 1 addition & 0 deletions nnef/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ pub fn tdim(dim: &TDim) -> RValue {
TDim::MulInt(x, y) => RValue::Binary(numeric(x).boxed(), "*".to_string(), tdim(y).boxed()),
TDim::Div(x, y) => RValue::Binary(tdim(x).boxed(), "/".to_string(), numeric(y).boxed()),
TDim::Broadcast(_) => todo!(),
TDim::Min(_) | TDim::Max(_) => todo!(),
}
}

Expand Down

0 comments on commit 1cc9b8a

Please sign in to comment.