fix: Error with aten.view across Tensor memory #2464
Conversation
- Address error where `aten.view` is called on TRT output Tensors, which can be in a different memory format than Torch expects
- Specifically, TRT can modify tensor memory to optimize certain layers, but Torch's `view` operator depends on specific memory configurations which can be violated at runtime (but not at compile time, since Torch itself would run these configurations correctly)
- Add a custom lowering pass to replace `view` with `reshape`, avoiding this issue. Reshape will make a copy of the underlying Tensor if necessary
- Torch-TRT's `aten.view` implementation is the same as that for `aten.reshape`, and they share a schema, so no changes are needed on the converter side
- Add test case to validate new lowering pass
9d3248a to 7f88494
This PR looks good to me. Just curious: is this a problem that PyTorch would hit after converting ops, or after creating TRT Engines? (I ask because I think `aten.reshape.default` and `aten.view.default` are exactly the same.)
Thanks for the review @zewenli98 - the main issue here is that when `aten.view` runs in Torch on a TRT output Tensor whose memory layout was altered, it fails with: `RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.` In that sense, this lowering pass is much more for PyTorch (and fallback graphs) than it is for TensorRT
Aha I see, so does this mean we can omit
@zewenli98 - I believe the cases that
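The failure mode discussed in this thread can be reproduced in plain PyTorch, with no TensorRT involved: `view` requires a stride-compatible memory layout, while `reshape` falls back to a copy when a zero-copy view is impossible. A minimal illustrative sketch (not the PR's actual test case):

```python
import torch

# A transpose produces a non-contiguous tensor, analogous to a TRT
# output whose memory layout differs from what Torch expects.
t = torch.randn(2, 3).t()
assert not t.is_contiguous()

# view cannot reinterpret this layout without a copy, so it raises:
# "view size is not compatible with input tensor's size and stride ...
#  Use .reshape(...) instead."
try:
    t.view(6)
    raised = False
except RuntimeError:
    raised = True
assert raised

# reshape succeeds by copying the underlying data when necessary.
out = t.reshape(6)
assert out.shape == (6,)
```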
Description

- Address error where `aten.view` is called on TRT output Tensors, which can be in a different memory format than Torch expects
- Add a custom lowering pass to replace `view` with `reshape`, avoiding this issue. Reshape will make a copy of the underlying Tensor if necessary
- Torch-TRT's `aten.view` implementation is the same as that for `aten.reshape`, and they share a schema so no changes are needed on the converter side

Error: `RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.`
Addresses #2415
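A lowering pass of this kind can be sketched as a rewrite over a `torch.fx` graph, swapping `aten.view.default` call targets for `aten.reshape.default`; because the two overloads take the same positional arguments, the call arguments can be left untouched. This is an illustrative sketch under those assumptions, not the actual Torch-TensorRT implementation (`view_to_reshape` is a hypothetical name):

```python
import torch

def view_to_reshape(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Replace aten.view.default calls with aten.reshape.default.

    The two ops take the same positional arguments, so only the call
    target needs to change; recompile() regenerates the forward code.
    """
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.view.default:
            node.target = torch.ops.aten.reshape.default
    gm.recompile()
    return gm
```

Applied to a traced graph, such a pass lets a call that previously failed on a non-contiguous TRT output succeed, since `reshape` copies the underlying data when a zero-copy view is impossible.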
Type of change
Checklist: