From a1a8fd313c999060db675848f8b3de3e1c78e468 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Jul 2021 22:59:14 +0900 Subject: [PATCH 1/3] fix weight shape in torch.mm conversion --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 00fa9f597d06..7c40496e1088 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1878,7 +1878,7 @@ def Float(self, inputs, input_types): return _op.cast(inputs[0], "float32") def mm(self, inputs, input_types): - return _op.nn.dense(inputs[0], inputs[1]) + return _op.nn.dense(inputs[0], _op.transpose(inputs[1])) def bitwise_not(self, inputs, input_types): data = inputs[0] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2ec281094080..022d28134b32 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3853,6 +3853,15 @@ def test_fn(x, mask): verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"]) +def test_mm(): + def test_fn(x, y): + return torch.mm(x, y) + + x = torch.randn((100, 200)) + y = torch.randn((200, 100)) + verify_trace_model(test_fn, [x, y], ["llvm", "cuda"]) + + def test_unique(): def test_fn(is_sorted, return_inverse, return_counts): return lambda x: torch.unique(x, is_sorted, return_inverse, return_counts) @@ -4035,6 +4044,7 @@ def test_forward_nll_loss(): test_hard_swish() test_hard_sigmoid() test_forward_nll_loss() + test_mm() # Model tests test_resnet18() From 64a2919d511487c8587e4de8cdce3a3225a24b1e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Jul 2021 23:27:43 +0900 Subject: [PATCH 2/3] Revert "fix weight shape in torch.mm conversion" This reverts commit a1a8fd313c999060db675848f8b3de3e1c78e468. --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 7c40496e1088..00fa9f597d06 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1878,7 +1878,7 @@ def Float(self, inputs, input_types): return _op.cast(inputs[0], "float32") def mm(self, inputs, input_types): - return _op.nn.dense(inputs[0], _op.transpose(inputs[1])) + return _op.nn.dense(inputs[0], inputs[1]) def bitwise_not(self, inputs, input_types): data = inputs[0] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 022d28134b32..2ec281094080 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3853,15 +3853,6 @@ def test_fn(x, mask): verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"]) -def test_mm(): - def test_fn(x, y): - return torch.mm(x, y) - - x = torch.randn((100, 200)) - y = torch.randn((200, 100)) - verify_trace_model(test_fn, [x, y], ["llvm", "cuda"]) - - def test_unique(): def test_fn(is_sorted, return_inverse, return_counts): return lambda x: torch.unique(x, is_sorted, return_inverse, return_counts) @@ -4044,7 +4035,6 @@ def test_forward_nll_loss(): test_hard_swish() test_hard_sigmoid() test_forward_nll_loss() - test_mm() # Model tests test_resnet18() From a76abdc1f9ec128db54ac7c6a1e86308d486707d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Jul 2021 23:28:14 +0900 Subject: [PATCH 3/3] [Torch] remove unused conversion --- python/tvm/relay/frontend/pytorch.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 00fa9f597d06..118af5a30620 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1877,9 +1877,6 @@ def Float(self, inputs, input_types): assert len(inputs) == 1 return _op.cast(inputs[0], "float32") - def mm(self, inputs, input_types): - return _op.nn.dense(inputs[0], inputs[1]) - def bitwise_not(self, inputs, input_types): data = inputs[0] # The input tensor must be of integral or Boolean types.