-
Notifications
You must be signed in to change notification settings - Fork 489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Squeeze Onnx Import #1753
Squeeze Onnx Import #1753
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1753 +/- ##
==========================================
- Coverage 86.61% 86.41% -0.21%
==========================================
Files 700 735 +35
Lines 83423 85729 +2306
==========================================
+ Hits 72258 74083 +1825
- Misses 11165 11646 +481 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much for completing a missing ONNX op in Burn. One down!
You've got almost perfect. We just need a couple fixes:
- Add test case to verify axes by adding ONNX OpSet 13 (please see my inlined comments)
- Account for negative values in axes.
match key.as_str() { | ||
"axes" => return value.clone().into_i64s(), | ||
_ => {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! This will support ONNX OPset 13
Can we, to be sure it works, add a unit test for OPset 13? This is similar to unsqueeze with opset 16 and 13.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unsqueeze uses OPset 16 and 11. I think my original one used 16, so I added one for 13. Is that what you meant?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, both should be good. OpSet 13 has axes
attribute so it should work.
} | ||
} | ||
_ => panic!("Arg for squeeze must be tensor or scalar"), | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
axes
can contain negative values, which means counting dimensions from the back (see squeeze spec), and Burn squeeze only supports positive values (see doc). So we should account for this.
We are already doing this for one dimension (see code) for gather_config
, so you can see the logic. We need to do this for all items in axes
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your contribution! I have the same comments as already mentioned, and one question/comment regarding the squeeze op spec support.
This PR adds support for squeeze op on a single dim, so support is not complete but as long as it is explicit I have no issues adding full support if needed in another PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes! I have two minor comments and then it should be good to go! 🙂
if let Some(Data::Int64s(axes)) = &node.inputs[1].value { | ||
if axes.len() != 1 { | ||
panic!("Squeeze: Only one axis should be specified for squeezing."); | ||
} | ||
} else { | ||
panic!("Squeeze: Axes input must be an integer list."); | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're checking for axes in the attributes in squeeze_config
to support another opset we need to check here too to make sure the output is properly adjusted.
See the unsqueeze_update_output
function which is doing something similar (though it captures the axes to support unsqueeze on multiple axes).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, should be a little closer to what you had in mind now!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Thanks for making the changes. I only have one minor comment left (we lost a check with the last committed change and I think we should keep it).
Should be good to go after that!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for going through the changes!
LGTM 🚀
Hi there, aware this is closed but wanted to ask why Squeeze has to specify only one axis? Is there a plan to make it possible to Squeeze to take > 1 axis? Should I make an issue to track this? In the mean time would it work to split into two layers which each Squeeze a different axis? Sorry I am not an expert in this field
|
Hi @mtobin-tdab 👋 The ONNX spec is quite large so we don't necessarily aim to cover all the edge cases when adding support for an operator. That's why this PR was merged - we just make sure to panic when there's is a case which is not supported yet (such as yours). You can definitely open an issue to ask for support for multiple axes. That's what issues are for (bugs, feature requests, etc.). In the meantime, if you can generate your model with multiple squeeze nodes instead then yes that should work! |
Thank you @laggui ! |
@mtobin-tdab I think |
Ah wait, actually I think I see the issue. It looks like I didn't update the dimension inference, sorry about that! |
Oh I didn't even see the follow-up PR for |
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
#1714
Changes
Added the squeeze operation for to
burn-import
Testing