Skip to content

Commit

Permalink
optional second input for Squeeze (bad bad bad idea)
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Aug 15, 2023
1 parent 0ee4781 commit 8864e56
Showing 1 changed file with 39 additions and 16 deletions.
55 changes: 39 additions & 16 deletions onnx/src/ops/array/squeeze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ pub fn squeeze(
#[derive(Debug, Clone, Hash)]
struct Squeeze13;



impl Expansion for Squeeze13 {
fn name(&self) -> Cow<str> {
"Squeeze13".into()
Expand All @@ -31,16 +29,29 @@ impl Expansion for Squeeze13 {
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> InferenceResult {
check_input_arity(inputs, 2)?;
check_output_arity(outputs, 1)?;
s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
s.given_2(&inputs[0].shape, &inputs[1].value, move |s, shape, axes| {
let axes =
axes.cast_to::<i64>()?.as_slice::<i64>()?.iter().map(|i| *i as isize).collect();
let op = tract_hir::ops::array::Squeeze::new(Some(axes));
let out_shape = op.output_shape(&shape)?;
s.equals(&outputs[0].shape, out_shape)
})
if inputs.len() == 2 {
s.given_2(&inputs[0].shape, &inputs[1].value, move |s, shape, axes| {
let axes =
axes.cast_to::<i64>()?.as_slice::<i64>()?.iter().map(|i| *i as isize).collect();
let op = tract_hir::ops::array::Squeeze::new(Some(axes));
let out_shape = op.output_shape(&shape)?;
s.equals(&outputs[0].shape, out_shape)
})
} else {
s.given(&inputs[0].shape, move |s, shape| {
let axes = shape
.iter()
.enumerate()
.filter(|(_, dim)| dim.is_one())
.map(|(pos, _)| pos as isize)
.collect();
let op = tract_hir::ops::array::Squeeze::new(Some(axes));
let out_shape = op.output_shape(&shape)?;
s.equals(&outputs[0].shape, out_shape)
})
}
}

fn wire(
Expand All @@ -49,14 +60,26 @@ impl Expansion for Squeeze13 {
model: &mut TypedModel,
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
if let Some(axes) = model.outlet_fact(inputs[1])?.konst.as_ref() {
let axes =
axes.cast_to::<i64>()?.as_slice::<i64>()?.iter().map(|i| *i as isize).collect();
if inputs.len() == 2 {
if let Some(axes) = model.outlet_fact(inputs[1])?.konst.as_ref() {
let axes =
axes.cast_to::<i64>()?.as_slice::<i64>()?.iter().map(|i| *i as isize).collect();
let op = tract_hir::ops::array::Squeeze::new(Some(axes));
op.wire(prefix, model, &inputs[0..1])
} else {
bail!("Need axes to be a constant")
}
} else {
let axes = model
.outlet_fact(inputs[0])?
.shape
.iter()
.enumerate()
.filter(|(_, dim)| dim.is_one())
.map(|(pos, _)| pos as isize)
.collect();
let op = tract_hir::ops::array::Squeeze::new(Some(axes));
op.wire(prefix, model, &inputs[0..1])
} else {
bail!("Need axes to be a constant")
}
}

}

0 comments on commit 8864e56

Please sign in to comment.