Skip to content

Commit

Permalink
[Relay][ONNX] Batch_matmul to dense optimization (#8440)
Browse files Browse the repository at this point in the history
* [ONNX]Add batch_matmul to dense optimization

* Add extra check to avoid unnecessary reshape

Co-authored-by: Ubuntu <ubuntu@ip-172-31-14-16.us-west-2.compute.internal>
  • Loading branch information
ymwangg and Ubuntu authored Jul 13, 2021
1 parent d67514b commit 136f218
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
27 changes: 18 additions & 9 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,25 +678,34 @@ def _impl_v1(cls, inputs, attr, params):
# When performing a batch matmul, we need to properly handle N-dim shapes.
if a_rank > 2 or b_rank > 2:

def flatten_to_3d(x, x_shape):
def flatten_to_nd(x, x_shape, nd=3):
ndims = infer_shape(x_shape)[0]
if ndims == nd:
return x
newshape = _op.concatenate(
[
_expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
_op.strided_slice(x_shape, [ndims - 2], [ndims]),
_op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
],
0,
)
out = _op.reshape(x, fold_constant(newshape))
return out

# Convert a and b into 3 dimensional tensors.
a = flatten_to_3d(inputs[0], a_shape)
b = flatten_to_3d(inputs[1], b_shape)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
b_type = infer_type(inputs[1])
# Convert to dense if the second matrix is 2d and non-dynamic
if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
a = flatten_to_nd(inputs[0], a_shape, 2)
b = _op.transpose(inputs[1])
output = _op.nn.dense(a, b)
else:
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(inputs[0], a_shape, 3)
b = flatten_to_nd(inputs[1], b_shape, 3)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
output = _op.nn.batch_matmul(a, b)
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
Expand Down
7 changes: 4 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,8 +1173,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, target, dev):
verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, targets=[target])


# TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed
@tvm.testing.parametrize_targets("llvm")
@tvm.testing.uses_gpu
def test_batch_matmul(target, dev):
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, dev)
verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4), target, dev)
Expand All @@ -1183,6 +1182,8 @@ def test_batch_matmul(target, dev):
verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4), target, dev)
verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4), target, dev)
verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4), target, dev)
verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32), target, dev)
verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16), target, dev)


def verify_simple_dynamic_model(a_shape, b_shape, target, dev):
Expand Down Expand Up @@ -1221,7 +1222,6 @@ def verify_model(ex, a_shape, b_shape):
b_anys = [relay.Any()] * len(b_shape)

mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys})

ex = relay.create_executor("vm", mod=mod, device=dev, target=target)
verify_model(ex, a_shape, b_shape)
verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape])
Expand Down Expand Up @@ -4955,3 +4955,4 @@ def test_qlinearadd():
test_reverse_sequence()
test_eyelike()
test_qlinearconv()
test_batch_matmul()

0 comments on commit 136f218

Please sign in to comment.