support aten::type_as in the pytorch frontend (apache#5787)
* support aten::type_as in the pytorch frontend

* use _convert_data_type to convert torch type to tvm type and add more types in the type_as test
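
A minimal sketch of the torch-to-TVM dtype mapping that _convert_data_type is assumed to perform; the helper and table here are illustrative stand-ins, not the TVM source:

import torch

# Illustrative mapping only -- the real _convert_data_type lives in
# python/tvm/relay/frontend/pytorch.py and may cover more cases.
_TORCH_TO_TVM_DTYPE = {
    "float64": "float64",
    "float32": "float32",
    "float16": "float16",
    "int64":   "int64",
    "int32":   "int32",
    "int16":   "int16",
    "int8":    "int8",
}

def _convert_data_type_sketch(input_type):
    # Strip an optional "torch." prefix, then look up the TVM dtype string.
    name = str(input_type).replace("torch.", "")
    return _TORCH_TO_TVM_DTYPE[name]

assert _convert_data_type_sketch(torch.float32) == "float32"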
randxie authored and Trevor Morris committed Jun 15, 2020
1 parent a9aa8ac commit e21351c
Showing 2 changed files with 46 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1645,6 +1645,14 @@ def _impl(inputs, input_types):
    return _impl


def _type_as():
    # Cast the first input to the dtype of the second (torch.Tensor.type_as).
    def _impl(inputs, input_types):
        assert len(inputs) == 2
        assert len(input_types) == 2
        return _op.cast(inputs[0], _convert_data_type(input_types[1]))
    return _impl


def _add(prelude):
    # add_ is overloaded for tensor add and list concat
    def _impl(inputs, input_types):
@@ -1953,6 +1961,7 @@ def _get_convert_map(prelude):
        "aten::stack"       : _tensor_array_stack(prelude),
        "aten::__getitem__" : _list_getitem(prelude),
        "aten::len"         : _list_len(prelude),
        "aten::type_as"     : _type_as(),
    }
    return convert_map
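
With the entry registered, converting a graph that contains aten::type_as reduces to a Relay cast. A minimal sketch of the expression _type_as effectively builds, assuming a float32 input and a float64 reference tensor:

from tvm import relay

# What x.type_as(ref) lowers to when ref is float64: a dtype cast.
x = relay.var("x", shape=(1, 3), dtype="float32")
casted = relay.cast(x, "float64")
print(casted)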

37 changes: 37 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -27,6 +27,7 @@

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.contrib.nvcc import have_fp16
from tvm.relay.testing.config import ctx_list
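
The newly imported have_fp16 guards the half-precision case below: it takes a CUDA compute-capability string and reports whether fp16 is usable on that device. A quick sketch (the "7.0" value is only an example):

from tvm.contrib.nvcc import have_fp16

# Volta-class compute capability; expected to support fp16.
print(have_fp16("7.0"))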


@@ -837,6 +838,41 @@ def forward(self, *args):
    input_data = torch.rand(input_shape).float()
    verify_model(Size1().float().eval(), input_data=input_data)


def test_type_as():
    torch.set_grad_enabled(False)
    input_shape = [1, 3]

    def _create_module(dtype):
        class TypeAs(Module):
            def forward(self, *args):
                expected_type_tensor = torch.zeros(1, 3, dtype=dtype)
                return args[0].type_as(expected_type_tensor)

        return TypeAs()

    input_data = torch.randn(input_shape).float()
    verify_model(_create_module(torch.float64), input_data=input_data)
    verify_model(_create_module(torch.float32), input_data=input_data)
    verify_model(_create_module(torch.int64), input_data=input_data)
    verify_model(_create_module(torch.int32), input_data=input_data)
    verify_model(_create_module(torch.int16), input_data=input_data)
    verify_model(_create_module(torch.int8), input_data=input_data)

    if torch.cuda.is_available():
        check_fp16 = False
        try:
            # Only check half precision on supported hardware.
            if have_fp16(tvm.gpu(0).compute_version):
                check_fp16 = True
        except Exception:
            # If GPU is not enabled in TVM, skip the fp16 test.
            pass

        if check_fp16:
            verify_model(_create_module(torch.float16), input_data=input_data)
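
For reference, the semantics under test can be confirmed in plain PyTorch, independent of TVM:

import torch

a = torch.randn(1, 3)                       # float32 by default
ref = torch.zeros(1, 3, dtype=torch.int64)
b = a.type_as(ref)                          # cast a to ref's tensor type
assert b.dtype == torch.int64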


def test_forward_view():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
@@ -2575,6 +2611,7 @@ def test_forward_pretrained_bert_base_uncased():
    test_upsample()
    test_forward_upsample3d()
    test_to()
    test_type_as()
    test_forward_functional_pad()
    test_forward_zero_pad2d()
    test_forward_constant_pad1d()
