🐛 [Bug] Support for modules with multiple outputs seems broken in v1.2.0 #1368


Closed
cpjenkins opened this issue Sep 22, 2022 · 8 comments · Fixed by #1599
Labels: bug: regression, bug, component: lowering

Comments

@cpjenkins

Bug Description

It appears that modules with multiple outputs no longer compile when using dynamic input shapes in v1.2.0.

The following example works in v1.1.1 but fails in v1.2.0

import torch
import torch.nn as nn
import torch_tensorrt as trt

from torch import Tensor
from typing import List, Tuple

trt.logging.set_reportable_log_level(trt.logging.Level.Debug)

class Net(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)

        self.h = nn.Conv2d(1, 4, 3, padding=1)
        self.g = nn.Conv2d(1, 4, 3, padding=1)

    def forward(self, x) -> Tuple[Tensor, Tensor]:
        return self.h(x), self.g(x)

model = Net().eval()
model = torch.jit.trace(model, torch.randn(1, 1, 128, 128))
model = trt.compile(
    model.cuda(),
    inputs=[
        trt.Input(min_shape=(1, 1, 128, 128),
                  opt_shape=(4, 1, 256, 256),
                  max_shape=(8, 1, 512, 512))
    ],
    min_block_size=1,
    require_full_compilation=True
)

Fails with error:

RuntimeError: [Error thrown at core/conversion/conversion.cpp:230] Tuple type. Only a single tensor or a TensorList type is supported.

In v1.1.1, the graph returns two output tensors, while in v1.2.0 it creates an intermediate node (%13) and returns a single TupleConstruct output. Unfortunately, MarkOutputs in core/conversion/conversion.cpp now only receives a single tuple output and throws an error.

Graphs are given below:

v1.1.1

  %11 : Tensor = aten::_convolution(%x, %self.h.weight, %self.h.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %12 : Tensor = aten::_convolution(%x, %self.g.weight, %self.g.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.g # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  return (%11, %12)

v1.2.0

  %11 : Tensor = aten::_convolution(%x, %self.h.weight, %self.h.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %12 : Tensor = aten::_convolution(%x, %self.g.weight, %self.g.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.g # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %13 : (Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu), Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu)) = prim::TupleConstruct(%11, %12)
  return (%13)

Expected behavior

A return type of Tuple[Tensor, Tensor] should be treated as two separate outputs, not one.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version: v1.1.1 and v1.2.0
  • PyTorch Version (e.g. 1.0): 1.12.1
  • CPU Architecture: AMD x86_64
  • OS (e.g., Linux): Linux / Ubuntu 22.04
  • How you installed PyTorch: source (w/ c++11 abi)
  • Build command you used (if compiling from source): n/a
  • Are you using local sources or building from archives: local source
  • Python version: 3.8
  • CUDA version: 11.7
  • GPU models and configuration: RTX A6000
  • Any other relevant information:

Additional context

@cpjenkins added the bug label on Sep 22, 2022
@narendasan added the bug: triaged [not a bug] label and removed the bug label on Sep 23, 2022
@narendasan
Collaborator

This is a change in TorchScript, not Torch-TensorRT: the graph you get when you trace has changed. Torch-TensorRT has no control over this, so while graphs of the old form from PyTorch 1.11 will still work, TorchScript will no longer produce them going forward. However, the two graphs are functionally equivalent: when a function in Python returns multiple values, it is actually returning a tuple of values. The new graph is in fact more descriptive of the operations occurring and more in line with Python conventions.
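As a quick illustration of that point (a minimal sketch, independent of Torch-TensorRT): returning multiple values from a Python function really means returning one tuple, which is exactly what prim::TupleConstruct expresses in the traced graph.

def f():
    # "Returning two values" actually builds and returns a single tuple object
    return 1, 2

result = f()
print(type(result))  # <class 'tuple'>
a, b = result        # unpacking splits the tuple back into two values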

In Torch-TensorRT, the old system was designed to handle only functions of the form f(Tensor, ..., Tensor) -> (Tensor, ..., Tensor). We are currently overhauling the system to generically handle arbitrary collections of tensors; a preview of this feature is part of v1.2.0. Therefore, there are two options you can use to convert this model:

  1. Use the new experimental collections feature:
model = trt.compile(
    model.cuda(),
    input_signature=(trt.Input(  # Note: it is called input_signature, not inputs, and takes a tuple of the form you would use to call the original forward function
                  min_shape=(1, 1, 128, 128),
                  opt_shape=(4, 1, 256, 256),
                  max_shape=(8, 1, 512, 512)),),
    min_block_size=1,
    require_full_compilation=True
)
  2. Or you can manually disable evaluation of prim::TupleConstruct operations, which allows you to use the old API but just runs that final operation in PyTorch:
model = trt.compile(
    model.cuda(),
    inputs=[trt.Input(
                  min_shape=(1, 1, 128, 128),
                  opt_shape=(4, 1, 256, 256),
                  max_shape=(8, 1, 512, 512)),],
    min_block_size=1,
    require_full_compilation=True,
    torch_executed_ops=["prim::TupleConstruct"]
)

@cpjenkins
Author

Sorry, I should have been clearer: I am holding the version of Torch constant here (1.12.1), and the graphs above are the lowered graphs.

The input graph from torch does end with a TupleConstruct node:

GRAPH: [Torch-TensorRT] - Before lowering: graph(%self.1 : __torch__.Net,
      %x : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu)):
  %g : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::GetAttr[name="g"](%self.1)
  %h : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="h"](%self.1)
  %63 : Tensor = prim::CallMethod[name="forward"](%h, %x)
  %64 : Tensor = prim::CallMethod[name="forward"](%g, %x)
  %52 : (Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu), Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu)) = prim::TupleConstruct(%63, %64)
  return (%52)

But in TRT v1.1.1 the TupleConstruct is immediately reduced to two outputs in the lowered graph (while in v1.2.0 it is not reduced).

GRAPH: [Torch-TensorRT] - Torch-TensorRT.TorchScript Graph Lowering
GRAPH: [Torch-TensorRT] - Post unpack hardswish: graph(%x : Tensor):
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %4 : int = prim::Constant[value=1](), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %5 : bool = prim::Constant[value=0](), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %6 : bool = prim::Constant[value=1](), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %self.g.weight : Float(4, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.g.bias : Float(4, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value= 0.1712 -0.1192 -0.0082 -0.2468 [ CUDAFloatType{4} ]]()
  %self.h.weight : Float(4, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.h.bias : Float(4, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value= 0.2731 -0.0596 -0.0197 -0.1654 [ CUDAFloatType{4} ]]()
  %11 : Tensor = aten::_convolution(%x, %self.h.weight, %self.h.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %12 : Tensor = aten::_convolution(%x, %self.g.weight, %self.g.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.g # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  return (%11, %12)

And in v1.2.0

GRAPH: [Torch-TensorRT] - Torch-TensorRT.TorchScript Graph Lowering
GRAPH: [Torch-TensorRT] - Post unpack hardswish: graph(%x : Tensor):
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %4 : int = prim::Constant[value=1](), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %5 : bool = prim::Constant[value=0](), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %6 : bool = prim::Constant[value=1](), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %self.g.weight : Float(4, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.g.bias : Float(4, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value= 0.3308  0.0332 -0.2625 -0.2605 [ CUDAFloatType{4} ]]()
  %self.h.weight : Float(4, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.h.bias : Float(4, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=0.01 *  4.9374  8.3168 -11.0279 -8.1161 [ CUDAFloatType{4} ]]()
  %11 : Tensor = aten::_convolution(%x, %self.h.weight, %self.h.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.h # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %12 : Tensor = aten::_convolution(%x, %self.g.weight, %self.g.bias, %3, %3, %3, %5, %2, %4, %5, %5, %6, %6), scope: __module.g # /home/cjenkins/dev/mat/venv/lib/python3.8/site-packages/torch/nn/modules/conv.py:453:0
  %13 : (Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu), Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu)) = prim::TupleConstruct(%11, %12)
  return (%13)

Unfortunately, neither of the code snippets above works with dynamic inputs (both work fine with a fixed input); ATen throws a fit trying to create a tensor with a negative dimension.
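For reference, stock TorchScript exposes a pass that performs exactly this kind of tuple flattening. The sketch below is an illustration only, assuming torch._C._jit_pass_lower_all_tuples mirrors what the v1.1.1 lowering effectively did; it removes the trailing prim::TupleConstruct so the graph returns the two tensors directly.

import torch

# Net is the module defined at the top of this issue
traced = torch.jit.trace(Net().eval(), torch.randn(1, 1, 128, 128))
print(traced.graph)  # ends with prim::TupleConstruct and returns a single value
# Flatten all tuple construction/unpacking, including the graph outputs,
# so the graph returns its two tensors directly
torch._C._jit_pass_lower_all_tuples(traced.graph)
print(traced.graph)  # now returns two Tensor outputs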

@narendasan added the bug and component: lowering labels and removed the bug: triaged [not a bug] label on Sep 23, 2022
@cpjenkins
Author

@narendasan for future reference, is there an open Discord / development channel? I'm interested in contributing and don't want to dirty up the issues page unnecessarily.

@narendasan
Collaborator

We are on the PyTorch Slack (invite form here: https://pytorch.org/resources/); #jit-be-extension-trt-poc is where we used to discuss development, but I can look into getting an official channel created. We also monitor the PyTorch Discuss forum, and our own Discussions page (https://github.com/pytorch/TensorRT/discussions) is where we post designs for new features and related topics.

@narendasan
Collaborator

@gs-olive can you V2C this with your recent collections changes?

@gs-olive
Collaborator

gs-olive commented Feb 7, 2023

Just tested this and can confirm the model compiles and runs inference successfully with PR #1599. The outputs are still batched as "one" entry in the TorchScript IR, though, which is a byproduct of the inserted prim::TupleConstruct operation. The resultant output is still (Tensor, Tensor) as expected, despite the fact that the IR shows the output as a single value.

Batching of Tensor outputs into one object originates from a change in TorchScript and not in Torch-TensorRT, as shown in this snippet of the torch.jit.trace output (using Torch 2.0.0.dev20230128+cu117), which batches the outputs before returning them:

graph(%self.1 : __torch__.___torch_mangle_9.Net,
      %x : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu)):
  %g : __torch__.torch.nn.modules.conv.___torch_mangle_8.Conv2d = prim::GetAttr[name="g"](%self.1)
  %h : __torch__.torch.nn.modules.conv.___torch_mangle_7.Conv2d = prim::GetAttr[name="h"](%self.1)
  %63 : Tensor = prim::CallMethod[name="forward"](%h, %x)
  %64 : Tensor = prim::CallMethod[name="forward"](%g, %x)
  %52 : (Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu), Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu)) = prim::TupleConstruct(%63, %64)
  return (%52)
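A quick way to confirm this from the Python side (a hypothetical check, reusing the compiled model from the original script): the call still unpacks into two tensors even though the IR shows a single returned value.

out = model(torch.randn(1, 1, 128, 128).cuda())
# The IR returns one value (the tuple), but Python-side it is two tensors
assert isinstance(out, tuple) and len(out) == 2
x, y = out
print(x.shape, y.shape)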

@Charlyo

Charlyo commented Feb 17, 2023

Still happening in v1.3.0

@gs-olive
Collaborator

Hello, to make the script succeed in v1.3.0, the argument require_full_compilation=False would be needed, since full compilation with collection-based outputs is not supported in that version. PR #1599 will address this issue, and the script runs to completion as-is with that PR.

Regarding the prim::TupleConstruct node, the root cause of the fixed-shape typing displayed in the TorchScript IR is the use of torch.jit.trace as opposed to torch.jit.script. If we use script in this case, no shape annotations are displayed in the resulting graph.
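For comparison, a minimal sketch of the scripted variant (assuming the same Net module from the original report); torch.jit.script compiles from source rather than from an example input, so the graph carries no concrete shape annotations:

scripted = torch.jit.script(Net().eval())
# Output types appear as plain Tensor rather than fixed-shape
# Float(1, 4, 128, 128, strides=[...]) annotations
print(scripted.graph)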

Additionally, despite the indication of a fixed shape in the tensor form Float(1, 4, 128, 128, strides=[65536, 16384, 128, 1], requires_grad=0, device=cpu), inputs with a dynamic batch still function successfully for inference. I believe these shapes are primarily annotations from the torch.jit.trace call and may not reflect the actual runtime shape of the Tensor object. For instance, the following works on my machine running v1.3.0:

model = Net().eval()
model = torch.jit.trace(model, torch.randn(1, 1, 128, 128))
model = trt.compile(
    model.cuda(),
    inputs=[
        trt.Input(min_shape=(1, 1, 128, 128),
                  opt_shape=(4, 1, 256, 256),
                  max_shape=(8, 1, 512, 512))
    ],
    min_block_size=1,
    require_full_compilation=False
)

x, y = model(torch.randn(1, 1, 128, 128))
x, y = model(torch.randn(4, 1, 256, 256))
x, y = model(torch.randn(7, 1, 300, 300))

Note: Specifying multiple dynamic dimensions is not currently fully supported.
