Skip to content

Commit

Permalink
also supported the latest master (1.7)
Browse files Browse the repository at this point in the history
  • Loading branch information
masa authored and masahi committed Oct 10, 2020
1 parent 8d9dd2a commit 017334a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
8 changes: 5 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,9 @@ def _impl(inputs, input_types):

def _where():
def _impl(inputs, input_types):
if len(inputs) == 1:
return _nonzero(False)([inputs[0], True], input_types)

cond = inputs[0]
x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
return _op.where(cond, x, y)
Expand Down Expand Up @@ -2278,9 +2281,8 @@ def _impl(inputs, input_types):
ret = _op.transform.argwhere(data)

if is_numpy_style or (len(inputs) > 1 and inputs[1]):
# TODO(kevinthesun): Support this by adding unbind op
# ret = _unbind()([ret, 0], None)
raise RuntimeError("as_tuple is not supported yet for nonzero.")
return _unbind()([ret, 1], None)

return ret

return _impl
Expand Down
13 changes: 11 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2865,10 +2865,19 @@ class Where2(Module):
def forward(self, *args):
return torch.where(args[0] > 0, args[0], args[1])

class Where3(Module):
def forward(self, *args):
return torch.where(args[0])[0]

x = torch.rand([3, 2]).float()
verify_model(Where1().float().eval(), input_data=[x])
verify_model(Where1(), input_data=[x])
y = torch.rand([3, 2])
verify_model(Where2().float().eval(), input_data=[x, y])
verify_model(Where2(), input_data=[x, y])

# a single argument variant, equivalent to torch.nonzero(..., as_tuple=True)
inp = torch.rand([10])
inp[3:8] = 0
verify_trace_model(Where3(), [inp], ["llvm"])


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 017334a

Please sign in to comment.