Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

onnx: ReduceMin/Max Ops #2563

Merged
merged 13 commits into from
Oct 15, 2024
174 changes: 173 additions & 1 deletion candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

pub type Value = Tensor;

Expand Down Expand Up @@ -1189,6 +1189,92 @@ fn simple_eval_(
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax
"ReduceMax" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;

let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};

let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();

let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};

axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();

if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}

if axes.len() > 1 {
axes.sort();
}

Some(axes)
} else {
None
};

// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}

let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.max_keepdim(axis)?
} else {
result.max(axis)?
}
}

result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.max_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.max(0)?
}
}
};

values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
Expand All @@ -1212,6 +1298,92 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin
"ReduceMin" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;

let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};

let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();

let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};

axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();

if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}

if axes.len() > 1 {
axes.sort();
}

Some(axes)
} else {
None
};

// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields positive infinity (if supported by the datatype) or the max value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}

let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.min_keepdim(axis)?
} else {
result.min(axis)?
}
}

result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.min_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.min(0)?
}
}
};

values.insert(node.output[0].clone(), output);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
// Version 18 impl
"Split" => {
Expand Down
Loading
Loading