Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug 1645 (Unsqueeze OpSet 11) #1661

Merged
merged 4 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ fn main() {
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
laggui marked this conversation as resolved.
Show resolved Hide resolved
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
.input("tests/mask_where/mask_where.onnx")
.out_dir("model/")
.run_from_script();
Expand Down
26 changes: 25 additions & 1 deletion crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ include_models!(
conv_transpose2d,
pow,
pow_int,
unsqueeze
unsqueeze,
unsqueeze_opset16,
unsqueeze_opset11
);

#[cfg(test)]
Expand Down Expand Up @@ -1018,6 +1020,28 @@ mod tests {
assert_eq!(output.shape(), expected_shape);
}

#[test]
fn unsqueeze_opset16() {
let device = Default::default();
let model = unsqueeze_opset16::Model::<Backend>::new(&device);
let input_shape = Shape::from([3, 4, 5]);
let expected_shape = Shape::from([3, 4, 5, 1]);
let input = Tensor::ones(input_shape, &device);
let output = model.forward(input);
assert_eq!(expected_shape, output.shape());
}

#[test]
fn unsqueeze_opset11() {
let device = Default::default();
let model = unsqueeze_opset11::Model::<Backend>::new(&device);
let input_shape = Shape::from([3, 4, 5]);
let expected_shape = Shape::from([3, 4, 5, 1]);
let input = Tensor::ones(input_shape, &device);
let output = model.forward(input);
assert_eq!(expected_shape, output.shape());
}

#[test]
fn cast() {
let device = Default::default();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pytorch2.2.1:Ž
:
onnx::Unsqueeze_01
/Unsqueeze" Unsqueeze*
axes@ 
main_graphZ'
onnx::Unsqueeze_0



b
1




B
laggui marked this conversation as resolved.
Show resolved Hide resolved
Binary file not shown.
6 changes: 3 additions & 3 deletions crates/burn-import/onnx-tests/tests/unsqueeze/unsqueeze_torch.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def main():
model.eval()
device = torch.device("cpu")

file_name = "unsqueeze_torch.onnx"
test_input = torch.randn(3, 4, 5, device=device)
model = Model()

output = model.forward(test_input)

torch.onnx.export(model, (test_input), file_name, verbose=False, opset_version=16)
torch.onnx.export(model, (test_input), "unsqueeze_opset16.onnx", verbose=False, opset_version=16)
torch.onnx.export(model, (test_input), "unsqueeze_opset11.onnx", verbose=False, opset_version=11)

print(f"Finished exporting model to {file_name}")
print(f"Finished exporting model")

# Output some test data for use in the test
print(f"Test input data of ones: {test_input}")
Expand Down
109 changes: 30 additions & 79 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,92 +251,43 @@ fn reduce_mean_update_outputs(node: &mut Node) {
}
}

//fn __unsqueeze_shape
/// Infers the shape of the output from the input and axes
/// Right now, this should only be called if the rhs is a constant
/// Update the output tensor dimension based on the "axes" attribute or the second input
fn unsqueeze_update_output(node: &mut Node) {
if node.inputs.len() != 2 {
panic!("Unsqueeze: wrong number of inputs");
}
// get the values while making sure the types are correct
let (input, axes) = match (&node.inputs[0].ty, &node.inputs[1].ty) {
(ArgType::Tensor(tensor), ArgType::Tensor(_axes)) => (
tensor.clone(),
match &node.inputs[1].value {
Some(value) => match &value {
Data::Int64s(axes) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
None => None,
let axes = if node.inputs.len() == 2 {
// get the values while making sure the types are correct
match &node.inputs[1].value {
Some(value) => match value {
Data::Int64s(axes) => Some(axes.clone()),
_ => panic!("Unsqueeze: invalid input types"),
},
),
_ => panic!("Unsqueeze: invalid input types"),
None => None,
}
} else {
node.attrs
.iter()
.find_map(|(key, value)| match key.as_str() {
"axes" => Some(value.clone().into_i64s()),
_ => None,
})
laggui marked this conversation as resolved.
Show resolved Hide resolved
};
//need output way up here to avoid borrowing issues
let mut tensor = match &node.outputs[0].ty {

// need output way up here to avoid borrowing issues
let input = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
};
match &axes {
//case 1: axes is constant -> output shape is input shape with 1s inserted at the axes
Some(dim_indices) => {
let output_rank = (dim_indices.len() + input.dim) as i64;
let mut dim_indices = dim_indices
.to_vec()
.iter()
.map(|&d| {
if (-output_rank..output_rank).contains(&d) {
(if d < 0 { d + output_rank } else { d }) as usize
} else {
panic!("Unsqueeze: invalid axis")
}
})
.collect::<Vec<usize>>();
dim_indices.sort_unstable();
let mut new_dims = vec![1; output_rank as usize];

tensor.dim = output_rank as usize;
let old_dims = input.shape.unwrap();
//Now use this to copy the chunks of the dims
let mut prev_idx: usize = 0;
let mut current_left_b: usize = 0;
let mut current_right_b: usize = 0;
let mut offset: usize = 0;

dim_indices.iter().for_each(|d| {
//check if there is space for at least one dimension
if prev_idx < *d {
current_right_b = *d - offset;

//copy the chunks of the dims
if current_right_b < old_dims.len() {
new_dims[prev_idx..*d]
.copy_from_slice(&old_dims[current_left_b..current_right_b])
} else {
new_dims[prev_idx..*d].copy_from_slice(&old_dims[current_left_b..]);
}
prev_idx = *d + 1;
//offset is equal to the number of extracted elements from the original shape
offset += current_right_b - current_left_b;
current_left_b = current_right_b;
} else {
//it's sorted so the only reason this would happen
//is if multiple indices are the same
prev_idx += 1;
}
});
//copy over anything past the index of the last new dimension
if current_left_b < old_dims.len() {
new_dims[prev_idx..].copy_from_slice(&old_dims[current_left_b..]);
}
tensor.shape = Some(new_dims);
node.outputs[0].ty = ArgType::Tensor(tensor.clone());
}

//case 3: output shape is dynamic -> black magic or unsupported
None => {
panic!("Unsqueeze: dynamic output shape is not currently supported");
}
let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
};

if axes.is_some() {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: input.dim + axes.unwrap().len(),
shape: None, // shape is calculated at runtime
..output
});
}
}

Expand Down
5 changes: 4 additions & 1 deletion crates/burn-import/src/onnx/from_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,10 @@ impl OnnxGraphBuilder {
/// Needs to be called after node renaming to ensure that the rhs name is correct
/// Needs to be called after constant lifting to ensure that the rhs value exists
fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) {
if node.node_type == NodeType::Unsqueeze && node.inputs[1].value.is_none() {
if node.node_type == NodeType::Unsqueeze
&& node.inputs.len() > 1
&& node.inputs[1].value.is_none()
{
if let Some(in_arg) = graph_io.get_node_output(&node.outputs[0].name) {
remap_unsqueeze_to_reshape(node, in_arg);
}
Expand Down
13 changes: 13 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,19 @@ pub fn reshape_config(node: &Node) -> Vec<i64> {
//Note this function should only execute if the second input is a constant
//if it wasn't and the output shape was known, unsqueeze has been remapped to reshape
pub fn unsqueeze_config(node: &Node) -> Vec<i64> {
// Check if axes attribute exists
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axes" => return value.clone().into_i64s(),
_ => {}
}
}

assert!(
!node.inputs.is_empty(),
"Unsqueeze: axes tensor must be present"
);

let input_value = &node.inputs[1];

match &node.inputs[1].ty {
Expand Down
Loading