Feat/squeeze dims #1779
Conversation
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

@@ Coverage Diff @@
##             main    #1779      +/-   ##
==========================================
+ Coverage   84.91%   86.41%   +1.49%
==========================================
  Files         756      737      -19
  Lines       87403    85977    -1426
==========================================
+ Hits        74218    74295      +77
+ Misses     13185    11682    -1503
==========================================

☔ View full report in Codecov by Sentry.
if dim_indices.contains(&index) && dim_size == 1 {
    check!(TensorCheck::squeeze::<D2>(index, &current_dims));
    continue;
}
@nathanielsimard this check doesn't do much of anything right now, since the `== 1` condition means there's no chance of it erroring out: the `squeeze()` function will panic if you ask it to squeeze a dimension that isn't 1. I'm not sure if you want to keep the same behavior as `squeeze()` or not. It looks like the ONNX spec says it will raise an error if you try squeezing a non-1 dimension, but PyTorch would just leave that dimension unchanged.
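To make the two behaviors concrete, here is a hedged, tensor-free Rust sketch operating on plain shape vectors (the helper names are illustrative and not part of the burn API): ONNX-style squeeze errors on a non-1 dimension, while PyTorch-style squeeze leaves it unchanged.

```rust
// Hypothetical sketch, not the burn implementation: contrasts the two
// squeeze semantics discussed above on plain shape vectors.

/// ONNX-style: squeezing a dimension whose size is not 1 is an error.
fn squeeze_onnx(shape: &[usize], dims: &[usize]) -> Result<Vec<usize>, String> {
    for &d in dims {
        if shape[d] != 1 {
            return Err(format!("cannot squeeze dim {d} of size {}", shape[d]));
        }
    }
    // Drop the requested (all size-1) dimensions.
    Ok(shape
        .iter()
        .enumerate()
        .filter(|(i, _)| !dims.contains(i))
        .map(|(_, &s)| s)
        .collect())
}

/// PyTorch-style: non-1 dimensions are silently left unchanged.
fn squeeze_torch(shape: &[usize], dims: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .enumerate()
        .filter(|&(i, &s)| !(dims.contains(&i) && s == 1))
        .map(|(_, &s)| s)
        .collect()
}
```

For a shape `[2, 1, 3]`, squeezing dim 2 is an `Err` under the ONNX-style rule, while the PyTorch-style version returns the shape untouched.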
In my opinion it should behave like our `squeeze`, so panic on non-1 dimensions.
Hi @agelas
Thanks a lot, it's very well documented and tested. I think we should panic on non-1 dimensions; that would be the only thing to change.
}

/// Test to make sure the function doesn't do anything if a non-singleton dimension is squeezed
#[test]
Following my other comment, this test should panic
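As a hedged, tensor-free illustration of a test that expects the panic (a hypothetical stand-in, not the actual burn test), a `#[should_panic]` attribute can encode the expectation:

```rust
// Hypothetical helper, not the burn API: panics when asked to squeeze a
// dimension whose size is not 1, mirroring the agreed-upon behavior.
fn squeeze_dim(shape: &[usize], dim: usize) -> Vec<usize> {
    assert_eq!(shape[dim], 1, "cannot squeeze dim {dim} of size {}", shape[dim]);
    let mut out = shape.to_vec();
    out.remove(dim);
    out
}

#[test]
#[should_panic]
fn squeeze_non_singleton_dim_panics() {
    // Dimension 2 has size 3, so this call is expected to panic.
    squeeze_dim(&[2, 1, 3], 2);
}
```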
Sounds good, I'll make the change.
LGTM
Pull Request Template

Checklist

- The `run-checks all` script has been executed.

Related Issues/PRs

#1780

Changes

Adds a new `squeeze_dims` function that is more spec-compliant with ONNX and PyTorch. Specifically, users can pass in multiple dimensions, negative indices, or no dimensions at all, in which case the function will squeeze all singleton dimensions. The `burn-import` crate has also been updated to use this.

Testing

Added unit tests to `burn-tensor`.
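As a rough sketch of the semantics described above (multiple dims, negative indices, or none at all), here is a hedged, standalone version operating on plain shape vectors; the name `squeeze_dims` and the panic on explicitly-requested non-singleton dims follow the review discussion, but this is not the actual burn implementation:

```rust
// Hypothetical stand-in for the squeeze_dims semantics, not the burn API.
fn squeeze_dims(shape: &[usize], dims: &[isize]) -> Vec<usize> {
    let rank = shape.len() as isize;
    let targets: Vec<usize> = if dims.is_empty() {
        // No dims given: target every singleton dimension.
        (0..shape.len()).filter(|&i| shape[i] == 1).collect()
    } else {
        dims.iter()
            .map(|&d| {
                // Normalize negative indices (-1 means the last dimension).
                let i = if d < 0 { (d + rank) as usize } else { d as usize };
                // Per the review discussion: panic on non-singleton dims.
                assert_eq!(shape[i], 1, "cannot squeeze dim {i} of size {}", shape[i]);
                i
            })
            .collect()
    };
    shape
        .iter()
        .enumerate()
        .filter(|(i, _)| !targets.contains(i))
        .map(|(_, &s)| s)
        .collect()
}
```

With a shape `[2, 1, 3, 1]`, passing no dims squeezes both singleton dimensions, while `-1` squeezes only the last one.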