separate weird onnx rules from the rest
kali committed Aug 29, 2023
1 parent f4a07c4 commit b3d928d
Showing 7 changed files with 81 additions and 82 deletions.
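In short: PaddingSpec's explicit-padding case is split in two, so the ONNX pooling rules (explicit pads plus a ceil_mode rounding flag) stop sharing a variant with plain explicit padding, and each call site now names the variant it actually means. A minimal sketch of the resulting enum, reconstructed from the hunks below (the variant names come from the diff; the TVec<usize> payload types are an assumption):

    // Sketch only, not the actual source file.
    pub enum PaddingSpec {
        Explicit(TVec<usize>, TVec<usize>),               // plain explicit pads (before, after)
        ExplicitOnnxPool(TVec<usize>, TVec<usize>, bool), // ONNX pool pads + ceil_mode flag
        Valid,
        SameUpper,
        SameLower,
    }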
27 changes: 10 additions & 17 deletions core/src/ops/cnn/conv/unary.rs
@@ -10,7 +10,7 @@ use crate::ops;
 use crate::ops::array::Pad;
 use crate::ops::array::PadMode;
 use crate::ops::binary::TypedBinOp;
-use crate::ops::cnn::PaddingSpec;
+use crate::ops::cnn::PaddingSpec::*;
 use crate::ops::einsum::EinSum;
 use crate::ops::math::Add;
 use crate::ops::math::Div;
@@ -382,8 +382,7 @@ impl ConvUnary {
                 pads,
             };
             wire = model.wire_node(format!("{name}.pad"), op, &[wire])?[0];
-            let valid_pool_spec =
-                PoolSpec { padding: ops::cnn::PaddingSpec::Valid, ..self.pool_spec.clone() };
+            let valid_pool_spec = PoolSpec { padding: Valid, ..self.pool_spec.clone() };
             b_fact = model.outlet_fact(wire)?.clone();
             let concrete_shape = b_fact.shape.as_concrete().unwrap();
             input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
@@ -640,9 +639,7 @@ impl ConvUnary {
         model: &TypedModel,
         node: &TypedNode,
     ) -> TractResult<Option<TypedModelPatch>> {
-        if self.pool_spec.padding != PaddingSpec::Valid
-            && !matches!(self.pool_spec.padding, PaddingSpec::ExplicitOnnxPool(_, _, _))
-        {
+        if matches!(self.pool_spec.padding, ExplicitOnnxPool(_, _, _) | SameLower | SameUpper) {
             return Ok(None);
         }
         let prec = model.node(node.inputs[0].node);
@@ -661,11 +658,11 @@
         }
         let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
         let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
-        if let PaddingSpec::ExplicitOnnxPool(bef, aft, false) = &self.pool_spec.padding {
+        if let Explicit(bef, aft) = &self.pool_spec.padding {
             izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
             izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
         }
-        let padding = PaddingSpec::ExplicitOnnxPool(before, after, false);
+        let padding = Explicit(before, after);
         let mut new = self.clone();
         new.pool_spec.padding = padding;
         let mut patch = TypedModelPatch::default();
@@ -825,8 +822,8 @@ impl TypedOp for ConvUnary {
                 self
             );
         }
-        if let PaddingSpec::ExplicitOnnxPool(before, after, _) = &self.pool_spec.padding {
-            anyhow::ensure!(before.len() == self.pool_spec.rank());
+        if let ExplicitOnnxPool(bef, after, _) | Explicit(bef, after) = &self.pool_spec.padding {
+            anyhow::ensure!(bef.len() == self.pool_spec.rank());
             anyhow::ensure!(after.len() == self.pool_spec.rank());
         }
         if let Some(bias) = &self.bias {
@@ -1117,7 +1114,6 @@ fn should_use_lazy(_input_shape: &DataShape, pool_spec: &PoolSpec, group: usize)
 mod test {
     use super::*;
     use crate::ops::array::Pad;
-    use crate::ops::cnn::PaddingSpec;
     use DataFormat::*;
 
     #[test]
@@ -1126,7 +1122,7 @@
             pool_spec: PoolSpec {
                 data_format: NCHW,
                 kernel_shape: tvec!(2, 2),
-                padding: PaddingSpec::Valid,
+                padding: Valid,
                 dilations: None,
                 strides: None,
                 output_channel_override: Some(1),
@@ -1171,7 +1167,7 @@ mod test {
                 dilations: None,
                 strides: None,
                 kernel_shape: tvec![2],
-                padding: crate::ops::cnn::PaddingSpec::ExplicitOnnxPool(tvec![0], tvec![0], false),
+                padding: Explicit(tvec![0], tvec![0]),
                 output_channel_override: Some(1),
             },
             kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
@@ -1186,10 +1182,7 @@
         model.declutter()?;
         assert_eq!(model.nodes().len(), 2); // source + conv
         let cv = model.nodes()[1].op_as::<ConvUnary>().unwrap();
-        assert_eq!(
-            cv.pool_spec.padding,
-            crate::ops::cnn::PaddingSpec::ExplicitOnnxPool(tvec![1], tvec![0], false)
-        ); // source + conv
+        assert_eq!(cv.pool_spec.padding, Explicit(tvec![1], tvec![0])); // source + conv
         Ok(())
     }
 }
80 changes: 44 additions & 36 deletions core/src/ops/cnn/padding.rs
@@ -10,6 +10,8 @@ pub enum PaddingSpec {
     SameLower,
 }
 
+use PaddingSpec::*;
+
 #[derive(Debug, Clone, new, PartialEq, Eq)]
 pub struct ComputedPaddedDim<D: DimLike> {
     pub deconvoluted: D,
@@ -21,24 +23,40 @@
 impl PaddingSpec {
     pub fn valid_dim(&self, d: usize, stride_is_one: bool) -> bool {
         match self {
-            PaddingSpec::Valid => true,
-            PaddingSpec::ExplicitOnnxPool(a, b, ceil_mode) => {
+            Valid => true,
+            Explicit(bef, aft) => bef[d] == 0 && aft[d] == 0,
+            ExplicitOnnxPool(a, b, ceil_mode) => {
                 (*ceil_mode || stride_is_one) && a[d] == 0 && b[d] == 0
             }
             _ => false,
         }
     }
 
-    pub fn rm_axis(&self, d: usize) -> PaddingSpec {
-        match self {
-            PaddingSpec::ExplicitOnnxPool(a, b, ceil_mode) => {
-                let mut a = a.clone();
-                let mut b = b.clone();
-                a.remove(d);
-                b.remove(d);
-                PaddingSpec::ExplicitOnnxPool(a, b, *ceil_mode)
+    pub fn change_geo_axes(&self, op: &AxisOp) -> TractResult<PaddingSpec> {
+        match &self {
+            ExplicitOnnxPool(before, after, round) => {
+                let mut before: TVec<usize> = before.clone();
+                let mut after: TVec<usize> = after.clone();
+                op.change_shape_array(&mut before, false)?;
+                op.change_shape_array(&mut after, false)?;
+                if let AxisOp::Add(add) = op {
+                    before[*add] = 0;
+                    after[*add] = 0;
+                }
+                Ok(ExplicitOnnxPool(before, after, *round))
             }
-            _ => self.clone(),
+            Explicit(before, after) => {
+                let mut before: TVec<usize> = before.clone();
+                let mut after: TVec<usize> = after.clone();
+                op.change_shape_array(&mut before, false)?;
+                op.change_shape_array(&mut after, false)?;
+                if let AxisOp::Add(add) = op {
+                    before[*add] = 0;
+                    after[*add] = 0;
+                }
+                Ok(Explicit(before, after))
+            }
+            Valid | SameLower | SameUpper => Ok(self.clone()),
         }
     }
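The new change_geo_axes helper centralizes the pad-array bookkeeping that pools.rs previously inlined (see that file below). A hypothetical usage sketch, assuming tract's tvec! macro and AxisOp::Add are in scope and we are inside a function returning TractResult — when an axis is inserted, the new geometric axis gets zero padding on both sides:

    let spec = Explicit(tvec![1, 2], tvec![3, 4]);
    // change_shape_array() grows the pad arrays for the inserted axis,
    // then the AxisOp::Add branch zeroes the new slot on both sides:
    assert_eq!(spec.change_geo_axes(&AxisOp::Add(0))?, Explicit(tvec![0, 1, 2], tvec![0, 3, 4]));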

@@ -92,19 +110,17 @@ impl PaddingSpec {
         dilation: usize,
         stride: usize,
     ) -> ComputedPaddedDim<D> {
-        let r = match self {
-            PaddingSpec::Valid => Self::valid(input, kernel, dilation, stride),
-            PaddingSpec::Explicit(ref bef, ref aft) => {
+        match self {
+            Valid => Self::valid(input, kernel, dilation, stride),
+            Explicit(ref bef, ref aft) => {
                 Self::explicit(input, kernel, dilation, stride, bef[axis], aft[axis])
             }
-            PaddingSpec::ExplicitOnnxPool(ref bef, ref aft, ceil_mode) => Self::explicit_onnx_pool(
+            ExplicitOnnxPool(ref bef, ref aft, ceil_mode) => Self::explicit_onnx_pool(
                 input, kernel, dilation, stride, bef[axis], aft[axis], *ceil_mode,
             ),
-            PaddingSpec::SameUpper => Self::same(input, kernel, dilation, stride, true),
-            PaddingSpec::SameLower => Self::same(input, kernel, dilation, stride, false),
-        };
-        eprintln!("{self:?} axis:{axis} input:{input:?} kernel:{kernel} dilation:{dilation} stride:{stride} => {r:?}");
-        r
+            SameUpper => Self::same(input, kernel, dilation, stride, true),
+            SameLower => Self::same(input, kernel, dilation, stride, false),
+        }
     }
 
     pub fn compute_one_for_deconv<D: DimLike>(
@@ -117,24 +133,16 @@
         adjustment: usize,
     ) -> TractResult<ComputedPaddedDim<D>> {
         match self {
-            PaddingSpec::Valid => {
-                Self::valid_for_deconv(input, kernel, dilation, stride, adjustment)
-            }
-            PaddingSpec::SameUpper => {
-                Self::same_for_deconv(input, kernel, dilation, stride, adjustment, true)
-            }
-            PaddingSpec::SameLower => {
-                Self::same_for_deconv(input, kernel, dilation, stride, adjustment, false)
-            }
-            PaddingSpec::Explicit(ref bef, ref aft) => Self::explicit_for_deconv(
+            Valid => Self::valid_for_deconv(input, kernel, dilation, stride, adjustment),
+            SameUpper => Self::same_for_deconv(input, kernel, dilation, stride, adjustment, true),
+            SameLower => Self::same_for_deconv(input, kernel, dilation, stride, adjustment, false),
+            Explicit(ref bef, ref aft) => Self::explicit_for_deconv(
                 input, kernel, dilation, stride, bef[axis], aft[axis], adjustment,
             ),
             // unreachable ?
-            PaddingSpec::ExplicitOnnxPool(ref bef, ref aft, _ceil_mode) => {
-                Self::explicit_for_deconv(
-                    input, kernel, dilation, stride, bef[axis], aft[axis], adjustment,
-                )
-            }
+            ExplicitOnnxPool(ref bef, ref aft, _ceil_mode) => Self::explicit_for_deconv(
+                input, kernel, dilation, stride, bef[axis], aft[axis], adjustment,
+            ),
         }
     }
 
@@ -184,7 +192,7 @@ impl PaddingSpec {
         } else {
             let kernel_field = (kernel - 1) * dilation + 1;
             let dividend = input.clone() + bef + aft - kernel_field;
-            let output = dividend.div(stride) + 1;
+            let output = dividend.divceil(stride) + 1;
             ComputedPaddedDim::new(input.clone(), output, bef.into(), aft.into())
         }
     }
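The div → divceil change above flips the rounding of the output-length formula; from the surrounding code this appears to be the ONNX pooling path, where ceil_mode may count a final, partially covered window. A standalone numeric sketch of the difference (hypothetical values, not code from the diff):

    let (input, kernel, dilation, stride) = (6usize, 3, 1, 2);
    let kernel_field = (kernel - 1) * dilation + 1; // 3
    let dividend = input - kernel_field; // 3
    assert_eq!(dividend / stride + 1, 2); // floor: only fully covered windows
    assert_eq!((dividend + stride - 1) / stride + 1, 3); // ceil, as divceil computes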
14 changes: 1 addition & 13 deletions core/src/ops/cnn/pools.rs
@@ -92,19 +92,7 @@ impl PoolSpec {
         op.change_shape_array(&mut kernel_shape, false)?;
         let mut strides = self.strides().into_owned().into();
         op.change_shape_array(&mut strides, false)?;
-        let padding = if let PaddingSpec::ExplicitOnnxPool(before, after, round) = &self.padding {
-            let mut before: TVec<usize> = before.clone();
-            let mut after: TVec<usize> = after.clone();
-            op.change_shape_array(&mut before, false)?;
-            op.change_shape_array(&mut after, false)?;
-            if let AxisOp::Add(add) = op {
-                before[*add] = 0;
-                after[*add] = 0;
-            }
-            PaddingSpec::ExplicitOnnxPool(before, after, *round)
-        } else {
-            self.padding.clone()
-        };
+        let padding = self.padding.change_geo_axes(op)?;
         Ok(PoolSpec {
             data_format: self.data_format,
             kernel_shape,
4 changes: 2 additions & 2 deletions nnef/src/ops/nnef/ser.rs
@@ -149,7 +149,7 @@ pub fn make_conv_named_args<'a>(
     use tract_core::ops::cnn::PaddingSpec;
     let output_shape = pool_spec.data_format.shape(node.outputs[0].fact.shape.to_tvec())?;
     let padding = match &pool_spec.padding {
-        PaddingSpec::ExplicitOnnxPool(bef, after, _) => array(
+        PaddingSpec::ExplicitOnnxPool(bef, after, _) | PaddingSpec::Explicit(bef, after) => array(
             &bef.iter()
                 .zip(after.iter())
                 .map(|(a, b)| tuple_2(numeric(a), numeric(b)))
@@ -333,7 +333,7 @@ fn cnn_pool(
     wire = ast.force_variable(format!("{}_input", node.name), &wire);
     let conv_fragment = cnn_pool_fragment(ast, pool_spec.data_format, pool_spec.rank(), op_name);
     let padding = match &pool_spec.padding {
-        PaddingSpec::ExplicitOnnxPool(bef, after, _) => Some(
+        PaddingSpec::ExplicitOnnxPool(bef, after, _) | PaddingSpec::Explicit(bef, after) => Some(
             bef.iter()
                 .zip(after.iter())
                 .map(|(a, b)| tuple_2(numeric(a), numeric(b)))
2 changes: 1 addition & 1 deletion onnx/src/ops/nn/conv_transpose.rs
@@ -12,7 +12,7 @@ pub fn conv_transpose(
     _ctx: &ParsingContext,
     node: &NodeProto,
 ) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
-    let padding_spec = super::pad(node)?;
+    let padding_spec = super::pad(node, false)?;
     let strides = super::strides(node)?;
     let dilations = super::dilations(node)?;
     let adjustments = node.get_attr_opt_tvec::<usize>("output_padding")?;
6 changes: 3 additions & 3 deletions onnx/src/ops/nn/mod.rs
@@ -131,7 +131,7 @@ pub fn batch_normalization(
 }
 
 fn common_conv(node: &NodeProto) -> TractResult<cnn::Conv> {
-    let mut op = ops::cnn::Conv::default().padding(pad(node)?);
+    let mut op = ops::cnn::Conv::default().padding(pad(node, false)?);
     if let Some(kernel_shape) = node.get_attr_opt_tvec("kernel_shape")? {
         op = op.kernel_shape(kernel_shape);
     }
@@ -197,7 +197,7 @@ pub fn average_pool(
     node: &NodeProto,
 ) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
     let kernel_shape = node.get_attr_tvec("kernel_shape")?;
-    let pad = pad(node)?;
+    let pad = pad(node, true)?;
     let strides = strides(node)?;
     let count_include_pad = node.get_attr_opt("count_include_pad")?.unwrap_or(false);
     Ok((
@@ -284,7 +284,7 @@ pub fn max_pool(
     node: &NodeProto,
 ) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
     let kernel_shape = node.get_attr_tvec("kernel_shape")?;
-    let pad = pad(node)?;
+    let pad = pad(node, true)?;
     let strides = strides(node)?;
     Ok((
         Box::new(cnn::MaxPool::new(
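Threading a boolean through pad() is how the ONNX importer now picks between the two explicit variants: the pooling ops pass true while conv and conv_transpose pass false. The diff does not show pad() itself, so this is only a plausible sketch of the dispatch the flag enables (hypothetical helper, not the commit's code):

    // Hypothetical helper; the real pad() lives in onnx/src/ops/nn/mod.rs.
    fn explicit_padding(before: TVec<usize>, after: TVec<usize>, onnx_pool: bool, ceil_mode: bool) -> PaddingSpec {
        if onnx_pool {
            PaddingSpec::ExplicitOnnxPool(before, after, ceil_mode) // pooling rules, keeps ceil_mode
        } else {
            PaddingSpec::Explicit(before, after) // conv / conv_transpose
        }
    }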
(One more changed file was not loaded in this view.)
