Bug Description
- The graph contains a convolution node that runs in PyTorch and a TensorRT engine node. The conv node's weight and bias are lifted as placeholder inputs, so calling the exported program fails with a runtime error about a mismatch in the number of inputs (expected 3, got 1).
Error message:
```
    _check_input_constraints_for_graph(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/utils.py", line 48, in _check_input_constraints_for_graph
    check(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/utils.py", line 40, in check
    raise RuntimeError(msg)
RuntimeError: Unexpected number of inputs (expected 3, got 1)
```
- If we unlift these parameters (i.e. `conv_weight` and `conv_bias` are registered as `get_attr` nodes), there is a different error: `GraphModule does not contain attribute conv_weight`.
Reason:
This is because a syntax error occurs in `_create_graph_module_for_export`, so the resulting `gm` does not have these attributes.
To Reproduce
Install the nightly version of Torch-TRT:
```
pip install --pre torch-tensorrt --extra-index-url https://download.pytorch.org/whl/nightly/cu121
```
Run the following script to reproduce the error:
```python
import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        conv = self.conv(x)
        relu = self.relu(conv)
        mul = relu * 0.5
        return mul


input = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
model = MyModule().eval().cuda()

compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            input.shape, dtype=torch.float, format=torch.contiguous_format
        )
    ],
    "ir": "dynamo",
    "min_block_size": 1,
    # Keep convolution in PyTorch so the graph mixes a PyTorch conv node and a TensorRT engine.
    "torch_executed_ops": {"torch.ops.aten.convolution.default"},
}

exp_program = torch_tensorrt.dynamo.trace(model, **compile_spec)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, [input], ir="exported_program")

# Serialize and deserialize the exported program.
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")

outputs_pyt = model(input)
# Raises the RuntimeError shown above (expected 3 inputs, got 1).
outputs_trt = trt_exp_program(input)
```
Expected behavior
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0):
- PyTorch Version (e.g. 1.0):
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
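Most of the fields above can be collected programmatically. PyTorch also ships `python -m torch.utils.collect_env` for a fuller report. The sketch below is a hypothetical helper (`collect_environment` is not a library function) that uses only `platform` and standard `torch` attributes, and treats Torch-TensorRT as optional because importing it can fail on machines without TensorRT.

```python
import platform

import torch


def collect_environment():
    """Gather values for the Environment fields (illustrative helper)."""
    info = {
        "PyTorch Version": torch.__version__,
        "CUDA version": torch.version.cuda,  # None for CPU-only builds
        "Python version": platform.python_version(),
        "OS": platform.platform(),
        "CPU Architecture": platform.machine(),
    }
    try:
        import torch_tensorrt  # optional; may be missing or fail to load its libraries

        info["Torch-TensorRT Version"] = torch_tensorrt.__version__
    except (ImportError, OSError):
        info["Torch-TensorRT Version"] = "not installed"
    if torch.cuda.is_available():
        info["GPU models and configuration"] = ", ".join(
            torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
        )
    else:
        info["GPU models and configuration"] = "no CUDA device visible"
    return info


for key, value in collect_environment().items():
    print(f"- {key}: {value}")
```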