Skip to content

Commit

Permalink
[Torch] Fix converting torch slice op with dynamic slice length (apac…
Browse files Browse the repository at this point in the history
…he#7549)

* Fix converting torch slice op with dynamic slice length

* use isinstance

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
2 people authored and trevor-m committed May 11, 2021
1 parent 0468456 commit 75bcd37
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,13 @@ def slice(self, inputs, input_types):
)

# A fast path when slicing is nop.
if target_begin == 0 and target_end >= index_size_limit and stride == 1:
if (
isinstance(target_begin, int)
and isinstance(target_end, int)
and target_begin == 0
and target_end >= index_size_limit
and stride == 1
):
return data

# Process begin
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,13 +1497,22 @@ class SliceWithStride2(torch.nn.Module):
def forward(self, x):
return x[0::2, 0::2] + x[1::2, 1::2]

class DynamicLengthSlice(torch.nn.Module):
def forward(self, values, length):
return values[0:length]

input_data = torch.rand(input_shape).float()
verify_model(Slice1(), input_data=input_data)
verify_model(Slice2(), input_data=input_data)
verify_model(Slice3(), input_data=input_data)
verify_model(SliceWithStride(), input_data=torch.randn(1, 4))
verify_model(SliceWithStride2(), input_data=torch.randn(4, 4))

inp = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
slice_len = torch.tensor(2)
targets = ["llvm", "cuda"]
verify_trace_model(DynamicLengthSlice(), [inp, slice_len], targets)


@tvm.testing.uses_gpu
def test_forward_narrow():
Expand Down

0 comments on commit 75bcd37

Please sign in to comment.