🐛 [Bug] Support for modules with multiple outputs seems broken in v1.2.0 #1368
This is a change in TorchScript, not Torch-TensorRT: the graph you get when you trace has changed. Torch-TensorRT has no control over this, so while graphs of the old form from 1.11 will still work, TorchScript will no longer produce them going forward. The two graphs are functionally equivalent, however: when a function in Python returns multiple values, it is actually returning a tuple of values, so the new graph is more descriptive of the operations occurring and more in line with Python conventions. In Torch-TensorRT, the old system was designed to handle purely Tensor-in/Tensor-out functions. For tuple outputs, you can either use the input_signature API:
model = trt.compile(
    model.cuda(),
    # Note: it is called input_signature, not inputs, and it takes a tuple of
    # the form you would use to call the original forward function
    input_signature=(trt.Input(min_shape=(1, 1, 128, 128),
                               opt_shape=(4, 1, 256, 256),
                               max_shape=(8, 1, 512, 512)),),
    min_block_size=1,
    require_full_compilation=True
)
or, alternatively, keep the inputs API and have the prim::TupleConstruct node run in PyTorch:

model = trt.compile(
    model.cuda(),
    inputs=[trt.Input(min_shape=(1, 1, 128, 128),
                      opt_shape=(4, 1, 256, 256),
                      max_shape=(8, 1, 512, 512))],
    min_block_size=1,
    require_full_compilation=True,
    torch_executed_ops=["prim::TupleConstruct"]
)
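For illustration, here is a minimal sketch (the function f is hypothetical, not from the thread) showing how a Python multi-value return appears in a traced TorchScript graph:

import torch

def f(x):
    # Returning two values in Python actually returns a tuple; TorchScript
    # records this as a prim::TupleConstruct node feeding the return.
    return x + 1, x * 2

traced = torch.jit.trace(f, torch.randn(2))
print(traced.graph)  # the last node before the return is prim::TupleConstruct

This is the graph shape that newer TorchScript versions emit, and it is why the torch_executed_ops workaround targets prim::TupleConstruct specifically.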
Sorry, I should have been clearer. I am holding the version of Torch constant here (1.12.1). The graphs above are the lowered graphs. The input graph from Torch does end with a TupleConstruct node:
But in Torch-TensorRT v1.1.1 the TupleConstruct is immediately reduced to two outputs in the lowered graph:

And in v1.2.0 it is not reduced:
Unfortunately, neither of the code snippets above works with dynamic inputs (both work fine with a fixed input): ATen throws a fit trying to create a tensor with a negative dimension.
@narendasan for future reference, is there an open Discord / development channel? I'm interested in contributing and don't want to clutter up the issues page unnecessarily.
We are on the PyTorch Slack (invite form here: https://pytorch.org/resources/); #jit-be-extension-trt-poc is where we used to discuss development, but I can look into getting an official channel created. We also monitor the PyTorch Discuss forum, and our own Discussions page (https://github.com/pytorch/TensorRT/discussions) is where we post designs for new features and related topics.
@gs-olive can you V2C this with your recent collections changes?
Just tested this and can confirm the model compiles and runs inference successfully with PR #1599. The outputs are still batched as "one" entry in the TorchScript IR, though; this is a byproduct of the inserted prim::TupleConstruct. Batching of Tensor outputs into one object originates from a change in TorchScript and not in Torch-TensorRT, as shown in this snippet of the graph:

graph(%self.1 : __torch__.___torch_mangle_9.Net,
%x : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu)):
%g : __torch__.torch.nn.modules.conv.___torch_mangle_8.Conv2d = prim::GetAttr[name="g"](%self.1)
%h : __torch__.torch.nn.modules.conv.___torch_mangle_7.Conv2d = prim::GetAttr[name="h"](%self.1)
%63 : Tensor = prim::CallMethod[name="forward"](%h, %x)
%64 : Tensor = prim::CallMethod[name="forward"](%g, %x)
%52 : (Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu), Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu)) = prim::TupleConstruct(%63, %64)
return (%52)
Still happening in v1.3.0
Hello - to make the script succeed in v1.3.0, the argument require_full_compilation=False can be used, as in the snippet below. Additionally, despite the indication of a fixed shape in the traced tensor metadata, the compiled module runs with dynamic input shapes:

model = Net().eval()
model = torch.jit.trace(model, torch.randn(1, 1, 128, 128))
model = trt.compile(
    model.cuda(),
    inputs=[trt.Input(min_shape=(1, 1, 128, 128),
                      opt_shape=(4, 1, 256, 256),
                      max_shape=(8, 1, 512, 512))],
    min_block_size=1,
    require_full_compilation=False
)
# Compiled Torch-TensorRT modules expect CUDA inputs
x, y = model(torch.randn(1, 1, 128, 128).cuda())
x, y = model(torch.randn(4, 1, 256, 256).cuda())
x, y = model(torch.randn(7, 1, 300, 300).cuda())

Note: specifying multiple dynamic dimensions is not currently fully supported.
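If multiple dynamic dimensions are a problem, one conservative sketch (an assumption, not a suggestion from the thread) is to let only the batch dimension vary:

model = trt.compile(
    model.cuda(),
    inputs=[trt.Input(min_shape=(1, 1, 128, 128),
                      opt_shape=(4, 1, 128, 128),
                      max_shape=(8, 1, 128, 128))],  # only the batch dimension varies
    min_block_size=1,
    require_full_compilation=False
)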
Bug Description
It appears that modules with multiple outputs no longer compile when using dynamic input shapes in v1.2.0.
The following example works in v1.1.1 but fails in v1.2.0.
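A minimal reproduction consistent with the graph snippet quoted in the comments above (the Net definition and the torch_tensorrt-as-trt import alias are assumptions) might look like:

import torch
import torch_tensorrt as trt

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 1 -> 4 channels with unchanged spatial size, matching the
        # Float(1, 4, 128, 128) outputs shown in the graph snippet (assumed)
        self.g = torch.nn.Conv2d(1, 4, 3, padding=1)
        self.h = torch.nn.Conv2d(1, 4, 3, padding=1)

    def forward(self, x):
        return self.h(x), self.g(x)

model = torch.jit.trace(Net().eval(), torch.randn(1, 1, 128, 128))
model = trt.compile(
    model.cuda(),
    inputs=[trt.Input(min_shape=(1, 1, 128, 128),
                      opt_shape=(4, 1, 256, 256),
                      max_shape=(8, 1, 512, 512))],
    min_block_size=1,
    require_full_compilation=True
)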
Fails with error:
In v1.1.1, the graph returns two output tensors, while in v1.2.0 it creates an intermediate TupleConstruct node (%13) and returns that single tuple as the graph output. Unfortunately, MarkOutputs in core/conversion/converter.cpp now only receives a single tuple output and throws an error.
Graphs are given below:
v1.1.1
v1.2.0
Expected behavior
A return type of Tuple[Tensor, Tensor] should be treated as two separate outputs - not one.
Environment
Additional context