Skip to content

Commit

Permalink
[Torch] Object detection support update for PyTorch 1.6 (#6659)
Browse files Browse the repository at this point in the history
* update split

* fix

* cast nms output to int64

* add more comment and numel test

* fix lint

* also supported the latest master (1.7)

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
masahi and masa authored Oct 12, 2020
1 parent 0cdd285 commit b277f18
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 13 deletions.
61 changes: 50 additions & 11 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@
__all__ = ["from_pytorch"]


def _is_version_greater_than(ver):
import torch
from packaging import version

# Torch version > 1.4 changed upsampling API
return version.parse(torch.__version__) > version.parse(ver)


# List ADT utilities
def _infer_type_with_prelude(val, prelude):
body = _infer_type(val, prelude.mod)
Expand Down Expand Up @@ -413,13 +421,18 @@ def _impl(inputs, input_types):
def _split_with_sizes():
def _impl(inputs, input_types):
data = inputs[0]
sections = inputs[1]
dim = int(inputs[2])

if len(sections) == 1:
# a special case used in torchvision detection models
return _expr.TupleWrapper(_expr.Tuple([data]), 1)

split_index = 0
indices = []
sections = inputs[1]
for i in range(len(sections) - 1):
split_index += sections[i]
index, _ = try_infer_value(sections[i], lambda ret: int(ret))
split_index += index
indices.append(split_index)

return _op.split(data, indices, dim)
Expand Down Expand Up @@ -522,6 +535,9 @@ def _impl(inputs, input_types):

def _where():
def _impl(inputs, input_types):
if len(inputs) == 1:
return _nonzero(False)([inputs[0], True], input_types)

cond = inputs[0]
x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
return _op.where(cond, x, y)
Expand Down Expand Up @@ -1865,11 +1881,8 @@ def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

if _is_quantized_tensor(data, prelude):
import torch
from packaging import version

# Torch version > 1.4 changed upsampling API
if version.parse(torch.__version__) > version.parse("1.4.0"):
if _is_version_greater_than("1.4.0"):
num_inputs = 7
else:
num_inputs = 5
Expand Down Expand Up @@ -2172,9 +2185,11 @@ def _impl(inputs, input_types):
data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

# strided slice to get the dynamic result
return get_relay_op("strided_slice")(
ret = get_relay_op("strided_slice")(
data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
)
# in torchvision, indices from nms are int64
return _op.cast(ret, "int64")

return _impl

Expand Down Expand Up @@ -2266,9 +2281,8 @@ def _impl(inputs, input_types):
ret = _op.transform.argwhere(data)

if is_numpy_style or (len(inputs) > 1 and inputs[1]):
# TODO(kevinthesun): Support this by adding unbind op
# ret = _unbind()([ret, 0], None)
raise RuntimeError("as_tuple is not supported yet for nonzero.")
return _unbind()([ret, 1], None)

return ret

return _impl
Expand Down Expand Up @@ -2335,6 +2349,21 @@ def _impl(inputs, input_types):
return _impl


def _numel():
def _impl(inputs, input_types):
return _op.ndarray_size(inputs[0])

return _impl


def _empty():
def _impl(inputs, input_types):
shape = inputs[0]
return _op.zeros(shape, _convert_dtype_value(inputs[1]))

return _impl


def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
import torch
Expand Down Expand Up @@ -2673,6 +2702,10 @@ def _get_convert_map(prelude, default_dtype):
"aten::scatter": _scatter(),
"aten::scalar_tensor": _scalar_tensor(),
"aten::__interpolate": _interpolate(),
"aten::IntImplicit": _identity(),
"aten::tensor": _identity(), # used for example in tensor(1.0)
"aten::numel": _numel(),
"aten::empty": _empty(),
}
return convert_map

Expand All @@ -2681,7 +2714,13 @@ def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch

torch._C._jit_pass_inline(graph)
if _is_version_greater_than("1.5.0"):
# This is required for torchvision detection models from 1.6 above
# It is the same as _jit_pass_inline, except that it has some special
# case behaviors for some ops such as aten::__interpolate()
torch._C._jit_pass_onnx_function_substitution(graph)
else:
torch._C._jit_pass_inline(graph)


def _get_tensor_and_var(torch_tensor, name):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def no_data_full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for zeros and ones.
"""
if len(inputs) == 0:
return [_convert_shape(convert(attrs.shape))]
return [_full_shape_func(inputs[0])]


Expand Down
25 changes: 23 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2865,10 +2865,19 @@ class Where2(Module):
def forward(self, *args):
return torch.where(args[0] > 0, args[0], args[1])

class Where3(Module):
def forward(self, *args):
return torch.where(args[0])[0]

x = torch.rand([3, 2]).float()
verify_model(Where1().float().eval(), input_data=[x])
verify_model(Where1(), input_data=[x])
y = torch.rand([3, 2])
verify_model(Where2().float().eval(), input_data=[x, y])
verify_model(Where2(), input_data=[x, y])

# a single argument variant, equivalent to torch.nonzero(..., as_tuple=True)
inp = torch.rand([10])
inp[3:8] = 0
verify_trace_model(Where3(), [inp], ["llvm"])


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -3152,6 +3161,17 @@ def forward(self, data, index, src):
verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])


def test_numel():
class Numel(Module):
def forward(self, data):
return torch.tensor(torch.numel(data))

targets = _get_default_vm_targets()
verify_script_model(Numel(), [(1,)], targets)
verify_script_model(Numel(), [(3, 5)], targets)
verify_script_model(Numel(), [(3, 5, 8)], targets)


def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
Expand Down Expand Up @@ -3455,6 +3475,7 @@ def expected(x_shape, y_shape):
test_forward_unbind()
test_forward_nonzero()
test_forward_scatter()
test_numel()

# Model tests
test_resnet18()
Expand Down

0 comments on commit b277f18

Please sign in to comment.