Bug/Remove Squeeze Panic for Multiple Dimensions #2035
Conversation
Thanks for fixing the issue!
If providing multiple axes is supported we should probably add a test to cover the use case 🤔 What do you think?
I think that would definitely be helpful for catching things like this.
**Codecov Report**: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff            @@
##             main    #2035    +/-   ##
========================================
- Coverage   84.25%   84.24%   -0.01%
========================================
  Files         846      852       +6
  Lines      105456   105640     +184
========================================
+ Hits        88853    89000     +147
- Misses      16603    16640      +37
```

☔ View full report in Codecov by Sentry.
On the ONNX side, we generate the ONNX files to test with this script: https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py. Right now it only covers a single axis, but we could provide an input with multiple axes to squeeze. For example:
@laggui Hmm, so that behavior doesn't look supported on the torch->ONNX conversion side. When I use your example with it, I get:

```
Traceback (most recent call last):
  File "/burn/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py", line 46, in <module>
    main()
  File "/burn/crates/burn-import/onnx-tests/tests/squeeze/squeeze.py", line 32, in main
    torch.onnx.export(model, test_input, "squeeze_opset16.onnx", verbose=False, opset_version=16)
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
    graph = _optimize_graph(
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1940, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.py", line 930, in squeeze
    dim = symbolic_helper._get_const(dim, "i", "dim")
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 178, in _get_const
    return _parse_arg(value, desc)
  File "/burn/myenv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 83, in _parse_arg
    return int(node_val)
ValueError: only one element tensors can be converted to Python scalars
```

If we just had an ONNX graph though, I think it might be interpreted correctly on the burn side.
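As a side note (not part of the original exchange), NumPy's `squeeze` illustrates the multi-axis semantics the ONNX `Squeeze` op follows when given an `axes` tensor: every listed axis must have size 1, and all of them are removed in one call. The shape below matches the test input used later in the thread.

```python
import numpy as np

# Input shaped like the test case in the thread: [3, 4, 1, 5, 1]
x = np.zeros((3, 4, 1, 5, 1))

# Squeezing multiple axes at once removes every listed size-1 dimension.
y = np.squeeze(x, axis=(2, 4))
print(y.shape)  # (3, 4, 5)

# Squeezing an axis whose size is not 1 is an error, which is why a
# backend must validate each axis rather than assume a single dimension.
try:
    np.squeeze(x, axis=(0,))
except ValueError as e:
    print("ValueError:", e)
```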
Oh, well I guess you're right then. If we want to add a test we'll have to construct the graph manually. I could provide a script for that later when I get some time.

/edit: this should do the trick, I believe:

```python
import onnx
from onnx import helper, TensorProto

input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 4, 1, 5, 1])
output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 4, 5])

squeeze = helper.make_node(op_type="Squeeze", inputs=["input", "axes"], outputs=["output"], name="SqueezeOp")
axes = helper.make_tensor("axes", TensorProto.INT64, dims=[2], vals=[2, 4])

graph = helper.make_graph([squeeze], "SqueezeMultiple", [input], [output], [axes])
opset = helper.make_opsetid("", 13)
m = helper.make_model(graph, opset_imports=[opset])

onnx.checker.check_model(m, full_check=True)
onnx.save(m, "squeeze_multiple.onnx")
```
Ok, fingers crossed, everything should be working now. Testing the model you constructed also led me to find another bug in how output dimensions were calculated, so adding that test was a huge plus.
LGTM! Thanks for adding the multiple dim support and test cases 🙏
LGTM
Pull Request Template
Checklist
- [x] `run-checks all` script has been executed.

Related Issues/PRs
Fixes #2033
Changes
Removes the panic when a user passes more than one axis to a Squeeze node in their ONNX graph; this should've been part of #1779.
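For illustration only: the actual fix lives in burn-import's Rust code, but the shape logic a Squeeze node needs can be sketched with a hypothetical Python helper. Per the ONNX spec, negative axes are normalized against the input rank, every squeezed axis must have size 1, and all listed axes are removed in a single pass rather than panicking when more than one is given.

```python
def squeeze_output_shape(input_shape, axes):
    """Hypothetical sketch of ONNX Squeeze output-shape inference.

    Normalizes negative axes, validates that each squeezed axis has
    size 1, then drops all listed axes at once.
    """
    rank = len(input_shape)
    norm = {a + rank if a < 0 else a for a in axes}
    for a in norm:
        if not 0 <= a < rank:
            raise ValueError(f"axis {a} out of range for rank {rank}")
        if input_shape[a] != 1:
            raise ValueError(f"cannot squeeze axis {a} of size {input_shape[a]}")
    return [d for i, d in enumerate(input_shape) if i not in norm]

print(squeeze_output_shape([3, 4, 1, 5, 1], [2, 4]))  # [3, 4, 5]
print(squeeze_output_shape([3, 4, 1, 5, 1], [-1]))    # [3, 4, 1, 5]
```

This mirrors the multi-axis test case constructed earlier in the thread, where squeezing axes 2 and 4 of a `[3, 4, 1, 5, 1]` input yields `[3, 4, 5]`.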