[Torch] Object detection support update for PyTorch 1.6 #6659

Merged (6 commits) on Oct 12, 2020
Changes from all commits
61 changes: 50 additions & 11 deletions python/tvm/relay/frontend/pytorch.py
@@ -46,6 +46,14 @@
__all__ = ["from_pytorch"]


def _is_version_greater_than(ver):
    import torch
    from packaging import version

    # True if the installed torch is newer than `ver`, e.g. "1.5.0"
    return version.parse(torch.__version__) > version.parse(ver)


# List ADT utilities
def _infer_type_with_prelude(val, prelude):
    body = _infer_type(val, prelude.mod)
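For reference, `packaging.version` gives PEP 440 semantics, which a plain string comparison would get wrong once versions like `1.10.0` appear; a minimal sketch (the version strings are illustrative):

```python
from packaging import version

# PEP 440-aware comparison; torch.__version__ may also carry a local
# suffix such as "1.6.0+cu101", which parse() handles correctly
assert version.parse("1.6.0") > version.parse("1.4.0")
assert version.parse("1.10.0") > version.parse("1.4.0")  # "1.10.0" < "1.4.0" as plain strings
assert version.parse("1.6.0+cu101") > version.parse("1.4.0")
```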
@@ -413,13 +421,18 @@ def _impl(inputs, input_types):
 def _split_with_sizes():
     def _impl(inputs, input_types):
         data = inputs[0]
+        sections = inputs[1]
         dim = int(inputs[2])

+        if len(sections) == 1:
+            # a special case used in torchvision detection models
+            return _expr.TupleWrapper(_expr.Tuple([data]), 1)
+
         split_index = 0
         indices = []
-        sections = inputs[1]
         for i in range(len(sections) - 1):
-            split_index += sections[i]
+            index, _ = try_infer_value(sections[i], lambda ret: int(ret))
+            split_index += index
             indices.append(split_index)

         return _op.split(data, indices, dim)
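For intuition: `split_with_sizes` takes per-chunk lengths, while `relay.split` takes cut points, so the loop above converts one into the other by cumulative sum. A numpy sketch of the same transformation (the helper name is hypothetical):

```python
import numpy as np

def sizes_to_split_indices(sections):
    # cumulative sum of all but the last section length
    indices, acc = [], 0
    for s in sections[:-1]:
        acc += s
        indices.append(acc)
    return indices

a = np.arange(10)
print(sizes_to_split_indices([2, 3, 5]))  # [2, 5]
print([c.tolist() for c in np.split(a, sizes_to_split_indices([2, 3, 5]))])
# [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]
```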
@@ -522,6 +535,9 @@ def _impl(inputs, input_types):

def _where():
    def _impl(inputs, input_types):
        if len(inputs) == 1:
            return _nonzero(False)([inputs[0], True], input_types)

        cond = inputs[0]
        x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
        return _op.where(cond, x, y)
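PyTorch documents the single-argument form as equivalent to `torch.nonzero(cond, as_tuple=True)`, which is why it is forwarded to `_nonzero` here; a quick check in PyTorch itself (requires a recent torch):

```python
import torch

x = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
lhs = torch.where(x > 0)                   # single-argument form
rhs = torch.nonzero(x > 0, as_tuple=True)  # documented equivalent
assert all(torch.equal(a, b) for a, b in zip(lhs, rhs))
```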
@@ -1865,11 +1881,8 @@ def func(x):
             return _op.image.resize(x, out_size, "NCHW", method, coord_trans)

         if _is_quantized_tensor(data, prelude):
-            import torch
-            from packaging import version
-
             # Torch version > 1.4 changed upsampling API
-            if version.parse(torch.__version__) > version.parse("1.4.0"):
+            if _is_version_greater_than("1.4.0"):
                 num_inputs = 7
             else:
                 num_inputs = 5
@@ -2172,9 +2185,11 @@ def _impl(inputs, input_types):
         data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

         # strided slice to get the dynamic result
-        return get_relay_op("strided_slice")(
+        ret = get_relay_op("strided_slice")(
             data_slice, begin=_expr.const([0]), end=size, slice_mode="size"
         )
+        # in torchvision, indices from nms are int64
+        return _op.cast(ret, "int64")

     return _impl
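The final cast mirrors torchvision, whose `nms` returns its keep-indices as an int64 tensor, so the converted graph must produce the same dtype for downstream indexing ops. A quick check (requires torchvision):

```python
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])
keep = nms(boxes, scores, iou_threshold=0.5)
print(keep.dtype)  # torch.int64
```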

@@ -2266,9 +2281,8 @@ def _impl(inputs, input_types):
         ret = _op.transform.argwhere(data)

         if is_numpy_style or (len(inputs) > 1 and inputs[1]):
-            # TODO(kevinthesun): Support this by adding unbind op
-            # ret = _unbind()([ret, 0], None)
-            raise RuntimeError("as_tuple is not supported yet for nonzero.")
+            return _unbind()([ret, 1], None)

         return ret

     return _impl
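`_unbind()([ret, 1], None)` splits the `(num_nonzero, ndim)` index matrix produced by `argwhere` along axis 1, yielding one index vector per input dimension — exactly the `as_tuple=True` layout. A numpy sketch of the equivalence:

```python
import numpy as np

x = np.array([[1, 0], [0, 2]])
idx = np.argwhere(x)  # shape (num_nonzero, ndim)
# unbinding along axis 1 yields one 1-D index array per dimension
as_tuple = tuple(idx[:, d] for d in range(idx.shape[1]))
for a, b in zip(as_tuple, np.nonzero(x)):
    assert np.array_equal(a, b)
```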
@@ -2335,6 +2349,21 @@ def _impl(inputs, input_types):
    return _impl


def _numel():
    def _impl(inputs, input_types):
        return _op.ndarray_size(inputs[0])

    return _impl


def _empty():
    def _impl(inputs, input_types):
        shape = inputs[0]
        return _op.zeros(shape, _convert_dtype_value(inputs[1]))

    return _impl
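`torch.empty` only promises a tensor of the requested shape and dtype — its values are uninitialized — so lowering it to `zeros` is a valid, deterministic choice. Only the metadata is contractual:

```python
import torch

e = torch.empty(2, 3, dtype=torch.float32)  # values are arbitrary
z = torch.zeros(2, 3, dtype=torch.float32)  # one valid concrete filling
assert e.shape == z.shape and e.dtype == z.dtype
```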


def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
import torch
@@ -2673,6 +2702,10 @@ def _get_convert_map(prelude, default_dtype):
"aten::scatter": _scatter(),
"aten::scalar_tensor": _scalar_tensor(),
"aten::__interpolate": _interpolate(),
"aten::IntImplicit": _identity(),
"aten::tensor": _identity(), # used for example in tensor(1.0)
"aten::numel": _numel(),
"aten::empty": _empty(),
}
return convert_map
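`aten::tensor` wraps a Python scalar back into a tensor (e.g. `torch.tensor(1.0)`), and `aten::IntImplicit` is TorchScript's implicit tensor-to-int conversion; since the converter already carries the underlying value, mapping both to identity is enough. One way to see these nodes appear (the printed graph varies by torch version):

```python
import torch

@torch.jit.script
def f(x):
    # numel returns a Python int; torch.tensor wraps it back into a tensor,
    # which shows up as an aten::tensor node in the TorchScript graph
    return torch.tensor(torch.numel(x))

print(f.graph)  # look for aten::numel and aten::tensor
```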

@@ -2681,7 +2714,13 @@ def _run_jit_passes(graph):
""" The inline pass is necessary to unwrap prim::CallMethod """
import torch

torch._C._jit_pass_inline(graph)
if _is_version_greater_than("1.5.0"):
# This is required for torchvision detection models from 1.6 above
# It is the same as _jit_pass_inline, except that it has some special
# case behaviors for some ops such as aten::__interpolate()
torch._C._jit_pass_onnx_function_substitution(graph)
else:
torch._C._jit_pass_inline(graph)
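Both passes rewrite the TorchScript graph in place; a minimal sketch of applying them to a traced module (these `torch._C` passes are internal APIs and may change between releases):

```python
import torch
from packaging import version

model = torch.nn.Sequential(torch.nn.Linear(3, 2), torch.nn.ReLU())
traced = torch.jit.trace(model, torch.rand(1, 3))
graph = traced.graph

if version.parse(torch.__version__) > version.parse("1.5.0"):
    # inlining plus special cases for ops such as aten::__interpolate
    torch._C._jit_pass_onnx_function_substitution(graph)
else:
    torch._C._jit_pass_inline(graph)  # plain inlining of prim::CallMethod
print(graph)
```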


def _get_tensor_and_var(torch_tensor, name):
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -179,6 +179,8 @@ def no_data_full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for zeros and ones.
"""
if len(inputs) == 0:
return [_convert_shape(convert(attrs.shape))]
return [_full_shape_func(inputs[0])]
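With a static shape, `zeros` carries its output shape purely in the op attributes and takes no tensor inputs, which is the `len(inputs) == 0` branch above; a minimal Relay sketch:

```python
import tvm
from tvm import relay

# no tensor inputs here: the shape lives in the zeros op's attributes
z = relay.zeros(shape=(2, 3), dtype="float32")
mod = tvm.IRModule.from_expr(z)
print(relay.transform.InferType()(mod))  # expect Tensor[(2, 3), float32]
```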


25 changes: 23 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
@@ -2865,10 +2865,19 @@ class Where2(Module):
         def forward(self, *args):
             return torch.where(args[0] > 0, args[0], args[1])

+    class Where3(Module):
+        def forward(self, *args):
+            return torch.where(args[0])[0]
+
     x = torch.rand([3, 2]).float()
-    verify_model(Where1().float().eval(), input_data=[x])
+    verify_model(Where1(), input_data=[x])
     y = torch.rand([3, 2])
-    verify_model(Where2().float().eval(), input_data=[x, y])
+    verify_model(Where2(), input_data=[x, y])
+
+    # a single argument variant, equivalent to torch.nonzero(..., as_tuple=True)
+    inp = torch.rand([10])
+    inp[3:8] = 0
+    verify_trace_model(Where3(), [inp], ["llvm"])


@tvm.testing.uses_gpu
@@ -3152,6 +3161,17 @@ def forward(self, data, index, src):
    verify_trace_model(Scatter(1), [in_data, in_index, in_src], ["llvm"])


def test_numel():
    class Numel(Module):
        def forward(self, data):
            return torch.tensor(torch.numel(data))

    targets = _get_default_vm_targets()
    verify_script_model(Numel(), [(1,)], targets)
    verify_script_model(Numel(), [(3, 5)], targets)
    verify_script_model(Numel(), [(3, 5, 8)], targets)


def test_forward_pretrained_bert_base_uncased():
    ######################################################################
    # This is an example of how to run BERT models using TVM
@@ -3455,6 +3475,7 @@ def expected(x_shape, y_shape):
    test_forward_unbind()
    test_forward_nonzero()
    test_forward_scatter()
    test_numel()

    # Model tests
    test_resnet18()