Skip to content

Commit

Permalink
fix weight shape in torch.mm conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jul 3, 2021
1 parent 7e3f068 commit a1a8fd3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,7 +1878,7 @@ def Float(self, inputs, input_types):
return _op.cast(inputs[0], "float32")

def mm(self, inputs, input_types):
return _op.nn.dense(inputs[0], inputs[1])
return _op.nn.dense(inputs[0], _op.transpose(inputs[1]))

def bitwise_not(self, inputs, input_types):
data = inputs[0]
Expand Down
10 changes: 10 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3853,6 +3853,15 @@ def test_fn(x, mask):
verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"])


def test_mm():
def test_fn(x, y):
return torch.mm(x, y)

x = torch.randn((100, 200))
y = torch.randn((200, 100))
verify_trace_model(test_fn, [x, y], ["llvm", "cuda"])


def test_unique():
def test_fn(is_sorted, return_inverse, return_counts):
return lambda x: torch.unique(x, is_sorted, return_inverse, return_counts)
Expand Down Expand Up @@ -4035,6 +4044,7 @@ def test_forward_nll_loss():
test_hard_swish()
test_hard_sigmoid()
test_forward_nll_loss()
test_mm()

# Model tests
test_resnet18()
Expand Down

0 comments on commit a1a8fd3

Please sign in to comment.