fix: Address runtimes with 0D inputs #2188

Merged: 1 commit, Aug 22, 2023
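
For context on the title: a 0D ("zero-dimensional", i.e. scalar) tensor is one with an empty shape, which is the case this PR fixes in the runtime. A minimal illustration:

    import torch

    s = torch.tensor(3)    # a 0D (scalar) tensor
    print(s.dim())         # 0
    print(tuple(s.shape))  # () -- empty shape, rank 0
    print(s.item())        # 3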
2 changes: 1 addition & 1 deletion core/runtime/execute_engine.cpp
@@ -136,7 +136,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     TORCHTRT_CHECK(
         inputs[i].dtype() == expected_type,
         "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
-    auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
+    auto dims = core::util::toDims(inputs[i].sizes());
     auto shape = core::util::toVec(dims);
     LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
     compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
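
The one-line change above stops padding runtime input shapes to at least rank 1: toDims passes the tensor's shape through unchanged, so a 0D input now reports an empty shape to setInputShape instead of being promoted to [1], which would mismatch the rank of an engine binding built for a 0D input. A rough Python sketch of the two behaviors (to_dims and to_dims_pad are illustrative stand-ins for the C++ utilities, assuming toDimsPad left-pads with 1s):

    import torch

    def to_dims(sizes):
        # Stand-in for core::util::toDims: pass the shape through unchanged.
        return list(sizes)

    def to_dims_pad(sizes, pad_to):
        # Stand-in for core::util::toDimsPad, assuming it left-pads the
        # shape with 1s up to rank `pad_to`.
        return [1] * max(0, pad_to - len(sizes)) + list(sizes)

    scalar = torch.tensor(3)             # 0D input, shape ()
    print(to_dims_pad(scalar.shape, 1))  # [1] -- old behavior, rank forced to 1
    print(to_dims(scalar.shape))         # []  -- new behavior, rank 0 preserved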
18 changes: 9 additions & 9 deletions py/torch_tensorrt/_Input.py
@@ -32,11 +32,11 @@ class _ShapeMode(Enum):
     shape: Optional[
         Tuple[int, ...] | Dict[str, Tuple[int, ...]]
     ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype: _enums.dtype = (  # type: ignore[name-defined]
+    dtype: _enums.dtype = (
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
-    format: _enums.TensorFormat = (  # type: ignore[name-defined]
+    format: _enums.TensorFormat = (
         _enums.TensorFormat.contiguous
     )  #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)

@@ -208,7 +208,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
             return False

     @staticmethod
-    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
+    def _parse_dtype(dtype: Any) -> _enums.dtype:
         if isinstance(dtype, torch.dtype):
             if dtype == torch.long:
                 return _enums.dtype.long
@@ -236,7 +236,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
             )

     @staticmethod
-    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
+    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
         if dtype == _enums.dtype.long:
             return torch.long
         elif dtype == _enums.dtype.int32:
@@ -255,7 +255,7 @@ def is_trt_dtype(self) -> bool:
         return bool(self.dtype != _enums.dtype.long)

     @staticmethod
-    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
+    def _parse_format(format: Any) -> _enums.TensorFormat:
         if isinstance(format, torch.memory_format):
             if format == torch.contiguous_format:
                 return _enums.TensorFormat.contiguous
@@ -337,18 +337,18 @@ def from_tensor(
             A Input object.
         """
         if not (
-            t.is_contiguous(memory_format=torch.contiguous_format)
+            disable_memory_format_check
+            or t.is_contiguous(memory_format=torch.contiguous_format)

Comment on lines +340 to +341 (Collaborator, Author):
Certain 0D tensors do not have the is_contiguous attribute, specifically when fake-ified. Reordering the logic preserves the result of the original conditional expression but ensures we do not access that attribute when disable_memory_format_check is set.

             or t.is_contiguous(memory_format=torch.channels_last)
-            or disable_memory_format_check
         ):
             raise ValueError(
                 "Tensor does not have a supported memory format, supported formats are contiguous or channel_last"
             )
         frmt = (
             torch.contiguous_format
             if (
-                t.is_contiguous(memory_format=torch.contiguous_format)
-                or disable_memory_format_check
+                disable_memory_format_check
+                or t.is_contiguous(memory_format=torch.contiguous_format)
             )
             else torch.channels_last
         )
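
The fix relies only on Python's short-circuit evaluation of `or`: once disable_memory_format_check is truthy, the remaining operands are never evaluated. A self-contained sketch of the failure mode and the fix (FakeScalar is a hypothetical stand-in, not a torch class):

    import torch

    class FakeScalar:
        # Hypothetical stand-in for a fake-ified 0D tensor that lacks the
        # is_contiguous attribute entirely.
        pass

    t = FakeScalar()
    disable_memory_format_check = True

    # The old operand order evaluated t.is_contiguous(...) first, raising
    # AttributeError before the flag was consulted. The new order lets
    # `or` short-circuit on the flag, so the attribute is never touched
    # when the check is disabled.
    supported = disable_memory_format_check or t.is_contiguous(
        memory_format=torch.contiguous_format
    )
    print(supported)  # True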
42 changes: 40 additions & 2 deletions tests/py/dynamo/backend/test_specialized_models.py
@@ -1,7 +1,7 @@
-from utils import lower_graph_testing
-from torch.testing._internal.common_utils import run_tests, TestCase
 import torch
 import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+from utils import lower_graph_testing


 class TestFakeTensors(TestCase):
@@ -118,5 +118,43 @@ def forward(self, x):
         torch._dynamo.reset()


+class Test0DTensors(TestCase):
+    def test_0D_input(self):
+        class Tensor0DInput(torch.nn.Module):
+            def forward(self, x):
+                return x * 7
+
+        inputs = [
+            torch.tensor(
+                3,
+            )
+            .cuda()
+            .int(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Tensor0DInput())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg="0D-Tensor TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()