diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp index 77c88b465d..50b58a0bdb 100644 --- a/core/util/trt_util.cpp +++ b/core/util/trt_util.cpp @@ -216,7 +216,7 @@ nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros, boo // Replace all instances of -1, indicating dynamic dimension // with 0, indicating copy the dimension from another tensor // (Generally used for reshape operations) - if (use_zeros && d.d[i] == -1) { + if (use_zeros && d.d[i] == -1 && i < pos) { dims.d[j] = 0; // If zeros already exist in the dimensions (empty tensor), // Replace all instances of 0, indicating empty dimension diff --git a/tests/py/ts/models/test_models.py b/tests/py/ts/models/test_models.py index 3e042fc763..5678e8f648 100644 --- a/tests/py/ts/models/test_models.py +++ b/tests/py/ts/models/test_models.py @@ -1,12 +1,13 @@ +import copy import unittest -import torch_tensorrt as torchtrt +from typing import Dict + +import custom_models as cm +import timm import torch +import torch_tensorrt as torchtrt import torchvision.models as models -import copy -import timm -import custom_models as cm -from typing import Dict -from utils import cosine_similarity, COSINE_THRESHOLD +from utils import COSINE_THRESHOLD, cosine_similarity class TestModels(unittest.TestCase): @@ -152,6 +153,45 @@ def test_resnet18_half(self): msg=f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + def test_aten_unbind_dynamic(self): + class ATenUnbindDynamic(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + x1, x2, x3 = x.unbind(1) + y = torch.cat([x1, x2, x3], dim=0) + return y + + self.model = ATenUnbindDynamic().eval().to("cuda") + self.input = torch.randn((5, 3, 7, 64)).to("cuda") + self.scripted_model = torch.jit.script(self.model) + + compile_spec = { + "inputs": [ + torchtrt.Input( + min_shape=[1, 3, 1, 64], + opt_shape=[5, 3, 32, 64], + max_shape=[10, 3, 64, 64], + dtype=torch.float, + format=torch.contiguous_format, + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + "ir": "ts", + } + + trt_mod = torchtrt.compile(self.scripted_model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"ATen Unbind Dynamic TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + if __name__ == "__main__": unittest.main()