diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 31c78cfdea84..3c61749fc203 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -400,7 +400,13 @@ def slice(self, inputs, input_types): ) # A fast path when slicing is nop. - if target_begin == 0 and target_end >= index_size_limit and stride == 1: + if ( + isinstance(target_begin, int) + and isinstance(target_end, int) + and target_begin == 0 + and target_end >= index_size_limit + and stride == 1 + ): return data # Process begin diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 826edd051544..9f035ade7a21 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1497,6 +1497,10 @@ class SliceWithStride2(torch.nn.Module): def forward(self, x): return x[0::2, 0::2] + x[1::2, 1::2] + class DynamicLengthSlice(torch.nn.Module): + def forward(self, values, length): + return values[0:length] + input_data = torch.rand(input_shape).float() verify_model(Slice1(), input_data=input_data) verify_model(Slice2(), input_data=input_data) @@ -1504,6 +1508,11 @@ def forward(self, x): verify_model(SliceWithStride(), input_data=torch.randn(1, 4)) verify_model(SliceWithStride2(), input_data=torch.randn(4, 4)) + inp = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + slice_len = torch.tensor(2) + targets = ["llvm", "cuda"] + verify_trace_model(DynamicLengthSlice(), [inp, slice_len], targets) + @tvm.testing.uses_gpu def test_forward_narrow():