-
Notifications
You must be signed in to change notification settings - Fork 84
[torchlib] Improve handling of SymInt[] #2522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR improves the handling of SymInt[]
parameters in torch_lib operations by avoiding unnecessary conversion to INT64 tensors and enabling better optimization of static dimensions.
- Updates functions to accept
Sequence[INT64]
instead ofIntType
for shape parameters - Introduces a
merge_dims
helper function to optimize consecutive constant dimensions - Removes unnecessary
Abs
operations inaten_expand
by handling constant -1 values directly
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
onnxscript/function_libs/torch_lib/ops/core.py | Updates multiple torch operations to use new shape handling approach and removes unnecessary type casting |
onnxscript/function_libs/torch_lib/ops/common.py | Adds new merge_dims helper function to optimize consecutive constant dimensions |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2522 +/- ##
==========================================
+ Coverage 69.93% 69.98% +0.04%
==========================================
Files 216 216
Lines 26052 26058 +6
Branches 2616 2618 +2
==========================================
+ Hits 18219 18236 +17
+ Misses 6931 6920 -11
Partials 902 902 ☔ View full report in Codecov by Sentry. |
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Previously sizes coming in as
SymInt[]
are first concatenated as INT64 then used. This created inefficiencies where we could not process any static dims from the size list and had to treat the whole shape as dynamic. In aten_expand, this meant we needed to addAbs
on the shape.This change updates the functions that take
SymInt[]
such that they are no longer turned into INT64 first. I updated aten_expand to process constant-1
values so anAbs
is not required. I also added a helpermerge_dims
to create constants for consecutive constant dims first before concatinating.