Skip to content

Commit

Permalink
Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair (#7845)
Browse files Browse the repository at this point in the history

* [AutoScheduler] Fix incorrect array context device and hide info at the beginning

* Lint fix

* Lint fix

* update repo

* Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair

* update measure.py

* Lint fix

* fix bug && add unit test for PyTorch matmul

* update ut

* Lint fix

* update commit

* Lint fix
  • Loading branch information
yuchaoli authored Apr 15, 2021
1 parent 36b7dd9 commit b24fbe7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
24 changes: 19 additions & 5 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@ def matmul(self, inputs, input_types):
b_shape = self.infer_shape_with_prelude(inputs_1)

# When performing a batch matmul, we need to properly handle N-dim shapes.
if len(a_shape) > 2 or len(b_shape) > 2:
if len(a_shape) > 2 and len(b_shape) > 2:
# Convert a into a 3 dimensional tensors.
need_reshape_output = False
if len(a_shape) != 3:
Expand All @@ -1606,18 +1606,32 @@ def matmul(self, inputs, input_types):
if need_reshape_output:
return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
return output
elif len(a_shape) > 2:
inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])

# Otherwise a simple dense op will get the job done.
if len(b_shape) == 1:
input_1 = _op.expand_dims(inputs_1, 0, 1)
else:
if len(b_shape) > 2:
trans_axes = list(range(len(b_shape)))
trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
elif len(b_shape) == 2:
input_1 = _op.transpose(inputs_1, axes=(1, 0))
elif len(b_shape) == 1:
input_1 = _op.expand_dims(inputs_1, 0, 1)

out = _op.nn.dense(inputs_0, input_1)

if len(b_shape) == 1:
out = _op.squeeze(out, axis=[-1])

# Reshape output into a N dimensional tensor when a or b dim > 2
if len(a_shape) > 2:
out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
elif len(b_shape) > 2:
out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
out = _op.reshape(
_op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
)

return out

def expand(self, inputs, input_types):
Expand Down
31 changes: 27 additions & 4 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
return est


def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
def verify_model(
model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]
):
"""Assert that the output of a compiled model matches with that of its
baseline."""
if isinstance(model_name, str):
Expand Down Expand Up @@ -219,6 +221,20 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at

assert_shapes_match(baseline_output, compiled_output)
tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)

if expected_ops:

def visit(op):
if isinstance(op, tvm.ir.op.Op):
if op.name in expected_ops:
expected_ops.remove(op.name)

tvm.relay.analysis.post_order_visit(mod["main"].body, visit)

if expected_ops:
msg = "TVM Relay do not contain expected ops {}"
raise AssertionError(msg.format(expected_ops))

del model_name
del baseline_model
torch.cuda.empty_cache()
Expand Down Expand Up @@ -3304,17 +3320,24 @@ def forward(self, *args):
# matrix x matrix
tensor1 = torch.randn(10, 4)
tensor2 = torch.randn(4, 10)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])

# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(
MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
)

# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])

# broadcasted matrix x batched matrix
tensor1 = torch.randn(10, 4)
tensor2 = torch.randn(3, 4, 5)
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])

# batched matrix x batched matrix
tensor1 = torch.randn(1, 12, 14, 64)
Expand Down

0 comments on commit b24fbe7

Please sign in to comment.