Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Aug 28, 2023
1 parent 6e1226f commit f33c0e7
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 57 deletions.
12 changes: 6 additions & 6 deletions core/src/ops/cnn/conv/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ impl ConvUnary {
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if self.pool_spec.padding != PaddingSpec::Valid
&& !matches!(self.pool_spec.padding, PaddingSpec::Explicit(_, _, _))
&& !matches!(self.pool_spec.padding, PaddingSpec::ExplicitOnnxPool(_, _, _))
{
return Ok(None);
}
Expand All @@ -661,11 +661,11 @@ impl ConvUnary {
}
let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
if let PaddingSpec::Explicit(bef, aft, false) = &self.pool_spec.padding {
if let PaddingSpec::ExplicitOnnxPool(bef, aft, false) = &self.pool_spec.padding {
izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
}
let padding = PaddingSpec::Explicit(before, after, false);
let padding = PaddingSpec::ExplicitOnnxPool(before, after, false);
let mut new = self.clone();
new.pool_spec.padding = padding;
let mut patch = TypedModelPatch::default();
Expand Down Expand Up @@ -825,7 +825,7 @@ impl TypedOp for ConvUnary {
self
);
}
if let PaddingSpec::Explicit(before, after, _) = &self.pool_spec.padding {
if let PaddingSpec::ExplicitOnnxPool(before, after, _) = &self.pool_spec.padding {
anyhow::ensure!(before.len() == self.pool_spec.rank());
anyhow::ensure!(after.len() == self.pool_spec.rank());
}
Expand Down Expand Up @@ -1171,7 +1171,7 @@ mod test {
dilations: None,
strides: None,
kernel_shape: tvec![2],
padding: crate::ops::cnn::PaddingSpec::Explicit(tvec![0], tvec![0], false),
padding: crate::ops::cnn::PaddingSpec::ExplicitOnnxPool(tvec![0], tvec![0], false),
output_channel_override: Some(1),
},
kernel_fmt: crate::ops::cnn::KernelFormat::OIHW,
Expand All @@ -1188,7 +1188,7 @@ mod test {
let cv = model.nodes()[1].op_as::<ConvUnary>().unwrap();
assert_eq!(
cv.pool_spec.padding,
crate::ops::cnn::PaddingSpec::Explicit(tvec![1], tvec![0], false)
crate::ops::cnn::PaddingSpec::ExplicitOnnxPool(tvec![1], tvec![0], false)
); // source + conv
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/cnn/deconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub fn adjustments(
debug_assert_eq!(pool_spec.rank(), output_geo.len());
let rank = pool_spec.rank();
let pad: TVec<usize> = match &pool_spec.padding {
PaddingSpec::Explicit(beg, end, _) => (0..rank).map(|r| beg[r] + end[r]).collect(),
PaddingSpec::ExplicitOnnxPool(beg, end, _) => (0..rank).map(|r| beg[r] + end[r]).collect(),
PaddingSpec::Valid => tvec!(0; rank),
_ => todo!("Unsupported combination of deconvolution arguments"),
};
Expand Down
76 changes: 63 additions & 13 deletions core/src/ops/cnn/padding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use crate::internal::*;

#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum PaddingSpec {
Explicit(TVec<usize>, TVec<usize>, bool),
Explicit(TVec<usize>, TVec<usize>),
ExplicitOnnxPool(TVec<usize>, TVec<usize>, bool),
#[default]
Valid,
SameUpper,
Expand All @@ -21,7 +22,7 @@ impl PaddingSpec {
pub fn valid_dim(&self, d: usize, stride_is_one: bool) -> bool {
match self {
PaddingSpec::Valid => true,
PaddingSpec::Explicit(a, b, ceil_mode) => {
PaddingSpec::ExplicitOnnxPool(a, b, ceil_mode) => {
(*ceil_mode || stride_is_one) && a[d] == 0 && b[d] == 0
}
_ => false,
Expand All @@ -30,12 +31,12 @@ impl PaddingSpec {

pub fn rm_axis(&self, d: usize) -> PaddingSpec {
match self {
PaddingSpec::Explicit(a, b, ceil_mode) => {
PaddingSpec::ExplicitOnnxPool(a, b, ceil_mode) => {
let mut a = a.clone();
let mut b = b.clone();
a.remove(d);
b.remove(d);
PaddingSpec::Explicit(a, b, *ceil_mode)
PaddingSpec::ExplicitOnnxPool(a, b, *ceil_mode)
}
_ => self.clone(),
}
Expand Down Expand Up @@ -93,9 +94,12 @@ impl PaddingSpec {
) -> ComputedPaddedDim<D> {
let r = match self {
PaddingSpec::Valid => Self::valid(input, kernel, dilation, stride),
PaddingSpec::Explicit(ref bef, ref aft, ceil_mode) => {
Self::explicit(input, kernel, dilation, stride, bef[axis], aft[axis], *ceil_mode)
PaddingSpec::Explicit(ref bef, ref aft) => {
Self::explicit(input, kernel, dilation, stride, bef[axis], aft[axis])
}
PaddingSpec::ExplicitOnnxPool(ref bef, ref aft, ceil_mode) => Self::explicit_onnx_pool(
input, kernel, dilation, stride, bef[axis], aft[axis], *ceil_mode,
),
PaddingSpec::SameUpper => Self::same(input, kernel, dilation, stride, true),
PaddingSpec::SameLower => Self::same(input, kernel, dilation, stride, false),
};
Expand All @@ -122,9 +126,15 @@ impl PaddingSpec {
PaddingSpec::SameLower => {
Self::same_for_deconv(input, kernel, dilation, stride, adjustment, false)
}
PaddingSpec::Explicit(ref bef, ref aft, _ceil_mode) => Self::explicit_for_deconv(
PaddingSpec::Explicit(ref bef, ref aft) => Self::explicit_for_deconv(
input, kernel, dilation, stride, bef[axis], aft[axis], adjustment,
),
// unreachable ?
PaddingSpec::ExplicitOnnxPool(ref bef, ref aft, _ceil_mode) => {
Self::explicit_for_deconv(
input, kernel, dilation, stride, bef[axis], aft[axis], adjustment,
)
}
}
}

Expand Down Expand Up @@ -162,10 +172,48 @@ impl PaddingSpec {
stride: usize,
bef: usize,
aft: usize,
) -> ComputedPaddedDim<D> {
if let Ok(i) = input.to_dim().to_usize() {
let ints = Self::explicit_usize(i, kernel, dilation, stride, bef, aft);
ComputedPaddedDim::new(
input.clone(),
ints.convoluted.into(),
ints.pad_before.into(),
ints.pad_after.into(),
)
} else {
let kernel_field = (kernel - 1) * dilation + 1;
let dividend = input.clone() + bef + aft - kernel_field;
let output = dividend.div(stride) + 1;
ComputedPaddedDim::new(input.clone(), output, bef.into(), aft.into())
}
}
/// Integer (fully concrete) variant of the explicit-padding output-size
/// computation: given an input extent and explicit before/after padding,
/// returns the convolved output length alongside the unchanged pads.
///
/// Formula: output = (input + bef + aft - receptive_field) / stride + 1,
/// where receptive_field = (kernel - 1) * dilation + 1. The subtraction
/// saturates at zero so a kernel larger than the padded input yields an
/// output of 1 instead of underflowing.
fn explicit_usize(
    input: usize,
    kernel: usize,
    dilation: usize,
    stride: usize,
    bef: usize,
    aft: usize,
) -> ComputedPaddedDim<usize> {
    // Extent covered by the dilated kernel.
    let receptive_field = (kernel - 1) * dilation + 1;
    // Padded input length minus the kernel footprint, clamped to zero.
    let span = (input + bef + aft).saturating_sub(receptive_field);
    let convoluted = span / stride + 1;
    ComputedPaddedDim::new(input, convoluted, bef, aft)
}

fn explicit_onnx_pool<D: DimLike>(
input: &D,
kernel: usize,
dilation: usize,
stride: usize,
bef: usize,
aft: usize,
ceil_mode: bool,
) -> ComputedPaddedDim<D> {
if let Ok(i) = input.to_dim().to_usize() {
let ints = Self::explicit_usize(i, kernel, dilation, stride, bef, aft, ceil_mode);
let ints =
Self::explicit_onnx_pool_usize(i, kernel, dilation, stride, bef, aft, ceil_mode);
ComputedPaddedDim::new(
input.clone(),
ints.convoluted.into(),
Expand All @@ -182,7 +230,7 @@ impl PaddingSpec {
}
}

fn explicit_usize(
fn explicit_onnx_pool_usize(
input: usize,
kernel: usize,
dilation: usize,
Expand All @@ -202,7 +250,6 @@ impl PaddingSpec {
output -= 1;
}
}
let after = (output * stride) + kernel_field - 1 - input - bef;
ComputedPaddedDim::new(input, output, bef, aft)
}

Expand Down Expand Up @@ -317,7 +364,7 @@ mod tests {
#[test]
fn explicit_2() {
assert_eq!(
PS::explicit(&28usize, 3usize, 1, 1, 2, 2, true),
PS::explicit_onnx_pool(&28usize, 3usize, 1, 1, 2, 2, true),
ComputedPaddedDim::new(28, 30, 2, 2)
);
}
Expand All @@ -326,7 +373,7 @@ mod tests {
#[ignore = "ONNX weird output computation for explicit"]
fn explicit_3() {
assert_eq!(
PS::explicit(&2usize, 1usize, 1, 2, 0, 0, true),
PS::explicit_onnx_pool(&2usize, 1usize, 1, 2, 0, 0, true),
ComputedPaddedDim::new(2, 2, 0, 0)
);
}
Expand All @@ -340,6 +387,9 @@ mod tests {
// 012 345 678 9ab
#[test]
fn bug_explicit_stride() {
assert_eq!(PS::explicit(&12usize, 3usize, 1, 3, 0, 0, false), ComputedPaddedDim::new(12, 4, 0, 0));
assert_eq!(
PS::explicit_onnx_pool(&12usize, 3usize, 1, 3, 0, 0, false),
ComputedPaddedDim::new(12, 4, 0, 0)
);
}
}
2 changes: 1 addition & 1 deletion core/src/ops/cnn/patches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ pub mod test {
.unwrap()
.with_dilations(tvec!(dilation))
.with_kernel_shape(tvec!(kdim))
.with_padding(PaddingSpec::Explicit(tvec![pad_before], tvec![bad_after], true))
.with_padding(PaddingSpec::ExplicitOnnxPool(tvec![pad_before], tvec![bad_after], true))
.with_strides(tvec![stride])
.into_patch();
patch.output_shape[0]
Expand Down
4 changes: 2 additions & 2 deletions core/src/ops/cnn/pools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl PoolSpec {
op.change_shape_array(&mut kernel_shape, false)?;
let mut strides = self.strides().into_owned().into();
op.change_shape_array(&mut strides, false)?;
let padding = if let PaddingSpec::Explicit(before, after, round) = &self.padding {
let padding = if let PaddingSpec::ExplicitOnnxPool(before, after, round) = &self.padding {
let mut before: TVec<usize> = before.clone();
let mut after: TVec<usize> = after.clone();
op.change_shape_array(&mut before, false)?;
Expand All @@ -101,7 +101,7 @@ impl PoolSpec {
before[*add] = 0;
after[*add] = 0;
}
PaddingSpec::Explicit(before, after, *round)
PaddingSpec::ExplicitOnnxPool(before, after, *round)
} else {
self.padding.clone()
};
Expand Down
6 changes: 3 additions & 3 deletions harness/core-proptest-pulse/src/conv_plus_conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl Arbitrary for ConvOp {
.prop_flat_map(|(stride, dil, ker)| {
let padding = (ker - 1) * dil;
let explicit = (0..=padding).prop_map(move |right| {
PaddingSpec::Explicit(tvec!(padding - right), tvec!(right), false)
PaddingSpec::ExplicitOnnxPool(tvec!(padding - right), tvec!(right), false)
});
(Just((stride, dil, ker)), prop_oneof![Just(PaddingSpec::Valid), explicit])
})
Expand Down Expand Up @@ -257,7 +257,7 @@ fn stride() {
stride: 2,
dilation: 1,
ker: t(2),
padding: PaddingSpec::Explicit(tvec!(1), tvec!(0), false),
padding: PaddingSpec::ExplicitOnnxPool(tvec!(1), tvec!(0), false),
}],
};
cpc.run().unwrap();
Expand All @@ -275,7 +275,7 @@ fn three() {
stride: 1,
dilation: 1,
ker: t(2),
padding: PaddingSpec::Explicit(tvec!(1), tvec!(0), false),
padding: PaddingSpec::ExplicitOnnxPool(tvec!(1), tvec!(0), false),
},
],
};
Expand Down
2 changes: 1 addition & 1 deletion harness/core-proptest-pulse/src/deconv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ fn deconv2d() {
pool_spec: PoolSpec {
data_format: tract_hir::ops::nn::DataFormat::NCHW,
kernel_shape: tvec!(1, 3),
padding: cnn::PaddingSpec::Explicit(tvec!(0, 1), tvec!(0, 1), false),
padding: cnn::PaddingSpec::Explicit(tvec!(0, 1), tvec!(0, 1)),
strides: Some(tvec!(1, 2)),
dilations: Some(tvec![1, 1]),
output_channel_override: Some(2),
Expand Down
6 changes: 3 additions & 3 deletions harness/core-proptest-pulse/src/delay_plus_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl Arbitrary for DelayPlusPoolProblem {
.prop_flat_map(|(pool_window, factor, delay, stride)| {
let padding = pool_window - 1;
let explicit = (0..=padding).prop_map(move |right| {
PaddingSpec::Explicit(tvec!(padding - right), tvec!(right), false)
PaddingSpec::ExplicitOnnxPool(tvec!(padding - right), tvec!(right), false)
});
let min_input = delay + pool_window;
(
Expand Down Expand Up @@ -195,7 +195,7 @@ fn test_pad_right() {
delay: 0,
stride: 1,
pool_window: 2,
padding: PaddingSpec::Explicit(tvec!(0), tvec!(1), false),
padding: PaddingSpec::ExplicitOnnxPool(tvec!(0), tvec!(1), false),
}
.run()
.unwrap()
Expand All @@ -209,7 +209,7 @@ fn test_pad_right_2() {
delay: 1,
stride: 2,
pool_window: 2,
padding: PaddingSpec::Explicit(tvec!(0), tvec!(1), false),
padding: PaddingSpec::ExplicitOnnxPool(tvec!(0), tvec!(1), false),
}
.run()
.unwrap()
Expand Down
4 changes: 2 additions & 2 deletions nnef/src/ops/nnef/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ pub fn read_conv_parameters(
before.push(p[0]);
after.push(p[1]);
}
PaddingSpec::Explicit(before, after, false)
PaddingSpec::ExplicitOnnxPool(before, after, false)
};
let pool_spec = PoolSpec::new(
DataFormat::NCHW,
Expand Down Expand Up @@ -439,7 +439,7 @@ fn pool_spec_for_pools(
}
let spatial_pool_bef = DataFormat::NCHW.shape(&before)?.hw_dims().into();
let spatial_pool_aft = DataFormat::NCHW.shape(&after)?.hw_dims().into();
PaddingSpec::Explicit(spatial_pool_bef, spatial_pool_aft, false)
PaddingSpec::ExplicitOnnxPool(spatial_pool_bef, spatial_pool_aft, false)
};
Ok(PoolSpec::new(
DataFormat::NCHW,
Expand Down
4 changes: 2 additions & 2 deletions nnef/src/ops/nnef/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ pub fn make_conv_named_args<'a>(
use tract_core::ops::cnn::PaddingSpec;
let output_shape = pool_spec.data_format.shape(node.outputs[0].fact.shape.to_tvec())?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(bef, after, _) => array(
PaddingSpec::ExplicitOnnxPool(bef, after, _) => array(
&bef.iter()
.zip(after.iter())
.map(|(a, b)| tuple_2(numeric(a), numeric(b)))
Expand Down Expand Up @@ -333,7 +333,7 @@ fn cnn_pool(
wire = ast.force_variable(format!("{}_input", node.name), &wire);
let conv_fragment = cnn_pool_fragment(ast, pool_spec.data_format, pool_spec.rank(), op_name);
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(bef, after, _) => Some(
PaddingSpec::ExplicitOnnxPool(bef, after, _) => Some(
bef.iter()
.zip(after.iter())
.map(|(a, b)| tuple_2(numeric(a), numeric(b)))
Expand Down
23 changes: 13 additions & 10 deletions onnx/src/ops/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,25 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) {
reg.insert("Softsign", |_, _| Ok((expand(ops::activations::Softsign), vec![])));
}

fn pad(node: &NodeProto) -> TractResult<cnn::PaddingSpec> {
fn pad(node: &NodeProto, pool_rules: bool) -> TractResult<cnn::PaddingSpec> {
let ceil_mode = node.get_attr_opt::<isize>("ceil_mode")?.unwrap_or(0) == 1;
let default = match node.get_attr_opt_vec::<isize>("kernel_shape")? {
Some(shape) => {
cnn::PaddingSpec::Explicit(tvec!(0; shape.len()), tvec!(0; shape.len()), ceil_mode)
}
Some(shape) => cnn::PaddingSpec::ExplicitOnnxPool(
tvec!(0; shape.len()),
tvec!(0; shape.len()),
ceil_mode,
),
None => cnn::PaddingSpec::Valid,
};
if let Some(pads) = node.get_attr_opt_tvec("pads")? {
let len = pads.len();
return Ok(cnn::PaddingSpec::Explicit(
pads.iter().cloned().take(len / 2).collect(),
pads.iter().cloned().skip(len / 2).collect(),
ceil_mode,
));
let left = pads.iter().cloned().take(len / 2).collect();
let right = pads.iter().cloned().skip(len / 2).collect();
if pool_rules {
return Ok(cnn::PaddingSpec::ExplicitOnnxPool(left, right, ceil_mode));
} else {
return Ok(cnn::PaddingSpec::Explicit(left, right));
}
}
Ok(node
.get_attr_opt("auto_pad")?
Expand Down Expand Up @@ -303,7 +307,6 @@ pub fn parametric_softplus(
#[derive(Debug, Clone, Hash)]
struct Prelu;


impl Expansion for Prelu {
fn name(&self) -> Cow<str> {
"Prelu".into()
Expand Down
Loading

0 comments on commit f33c0e7

Please sign in to comment.